TF-Lite Model API Examples

This notebook demonstrates how to use the TF-Lite Model package.

NOTES:

  • Click here to Open In Colab and run this example interactively in your browser

  • Refer to the Notebook Examples Guide for how to run this example locally in VSCode

Install MLTK Python Package

# Install the MLTK Python package (if necessary)
!pip install --upgrade silabs-mltk
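
Optionally, you can confirm which version was installed. This is just a quick sanity check using pip itself and is not required by the examples that follow.

# Show the installed MLTK package version (optional sanity check)
!pip show silabs-mltk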

Import Python Packages

# Import the standard Python packages used by the examples
import os
import urllib.request
import shutil
import tempfile

Download .tflite model file

A .tflite model file is required to run these examples.
The following code downloads a model.

NOTE: Update TFLITE_MODEL_URL or tflite_path to point to your model if necessary

# Use the .tflite model found here:
# https://github.com/siliconlabs/mltk/tree/master/mltk/utils/test_helper/data/
# NOTE: Update this URL to point to your model if necessary
TFLITE_MODEL_URL = 'https://github.com/siliconlabs/mltk/raw/master/mltk/utils/test_helper/data/image_example1.tflite'

# Download the .tflite file and save to the temp dir
tflite_path = os.path.normpath(f'{tempfile.gettempdir()}/image_example1.tflite')
with open(tflite_path, 'wb') as dst:
    with urllib.request.urlopen(TFLITE_MODEL_URL) as src:
        shutil.copyfileobj(src, dst)
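
If you already have a .tflite model on disk, the download step can be skipped entirely; simply point tflite_path at your file instead. The path below is only a placeholder for illustration.

# OPTIONAL: Use a local .tflite instead of downloading one
# (replace the placeholder path with the actual location of your model)
# tflite_path = '/path/to/your_model.tflite'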

Example 1: Load .tflite and print summary

This example loads a .tflite model file and prints a summary of it:

# Import the TfliteModel class
from mltk.core import TfliteModel 

# Load the .tflite
tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)

# Generate a summary of the .tflite
summary = tflite_model.summary()

# Print the summary to the console
print(summary)
+-------+-----------------+-----------------+-----------------+-----------------------------------------------------+
| Index | OpCode          | Input(s)        | Output(s)       | Config                                              |
+-------+-----------------+-----------------+-----------------+-----------------------------------------------------+
| 0     | conv_2d         | 96x96x1 (int8)  | 48x48x24 (int8) | Padding:same stride:2x2 activation:relu             |
|       |                 | 3x3x1 (int8)    |                 |                                                     |
|       |                 | 24 (int32)      |                 |                                                     |
| 1     | average_pool_2d | 48x48x24 (int8) | 24x24x24 (int8) | Padding:valid stride:2x2 filter:2x2 activation:none |
| 2     | conv_2d         | 24x24x24 (int8) | 11x11x16 (int8) | Padding:valid stride:2x2 activation:relu            |
|       |                 | 3x3x24 (int8)   |                 |                                                     |
|       |                 | 16 (int32)      |                 |                                                     |
| 3     | conv_2d         | 11x11x16 (int8) | 9x9x24 (int8)   | Padding:valid stride:1x1 activation:relu            |
|       |                 | 3x3x16 (int8)   |                 |                                                     |
|       |                 | 24 (int32)      |                 |                                                     |
| 4     | average_pool_2d | 9x9x24 (int8)   | 4x4x24 (int8)   | Padding:valid stride:2x2 filter:2x2 activation:none |
| 5     | reshape         | 4x4x24 (int8)   | 384 (int8)      | Type=none                                           |
|       |                 | 2 (int32)       |                 |                                                     |
| 6     | fully_connected | 384 (int8)      | 3 (int8)        | Activation:none                                     |
|       |                 | 384 (int8)      |                 |                                                     |
|       |                 | 3 (int32)       |                 |                                                     |
| 7     | softmax         | 3 (int8)        | 3 (int8)        | Type=softmaxoptions                                 |
+-------+-----------------+-----------------+-----------------+-----------------------------------------------------+
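
Since the summary is generated as text, it can also be written to a file for later reference. A minimal sketch using only the standard library; the output filename is arbitrary:

# Save the summary next to the model in the temp directory
summary_path = os.path.normpath(f'{tempfile.gettempdir()}/image_example1_summary.txt')
with open(summary_path, 'w') as f:
    f.write(f'{summary}\n')
print(f'Summary written to: {summary_path}')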

Example 2: Iterate the model layers

This example loads a .tflite model file and iterates through the layers of the model:

# Import the TfliteModel class
from mltk.core import TfliteModel 

# Load the .tflite
tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)

# Iterate over each layer in the .tflite
for layer in tflite_model.layers:
    print(f'Layer {layer.name}:\n\tInputs={",".join(str(x) for x in layer.inputs)}\n\tOutputs={",".join(str(x) for x in layer.outputs)}')
Layer op0-conv_2d:
	Inputs=conv2d_input_int8, dtype:int8, shape:1x96x96x1,image_example1/conv2d/Conv2D, dtype:int8, shape:24x3x3x1,image_example1/conv2d/BiasAdd/ReadVariableOp/resource, dtype:int32, shape:24
	Outputs=image_example1/conv2d/Relu;image_example1/conv2d/BiasAdd;image_example1/conv2d_2/Conv2D;image_example1/conv2d/Conv2D;image_example1/conv2d/BiasAdd/ReadVariableOp/resource, dtype:int8, shape:1x48x48x24
Layer op1-average_pool_2d:
	Inputs=image_example1/conv2d/Relu;image_example1/conv2d/BiasAdd;image_example1/conv2d_2/Conv2D;image_example1/conv2d/Conv2D;image_example1/conv2d/BiasAdd/ReadVariableOp/resource, dtype:int8, shape:1x48x48x24
	Outputs=image_example1/average_pooling2d/AvgPool, dtype:int8, shape:1x24x24x24
Layer op2-conv_2d:
	Inputs=image_example1/average_pooling2d/AvgPool, dtype:int8, shape:1x24x24x24,image_example1/conv2d_1/Conv2D, dtype:int8, shape:16x3x3x24,image_example1/conv2d_1/BiasAdd/ReadVariableOp/resource, dtype:int32, shape:16
	Outputs=image_example1/conv2d_1/Relu;image_example1/conv2d_1/BiasAdd;image_example1/conv2d_1/Conv2D;image_example1/conv2d_1/BiasAdd/ReadVariableOp/resource, dtype:int8, shape:1x11x11x16
Layer op3-conv_2d:
	Inputs=image_example1/conv2d_1/Relu;image_example1/conv2d_1/BiasAdd;image_example1/conv2d_1/Conv2D;image_example1/conv2d_1/BiasAdd/ReadVariableOp/resource, dtype:int8, shape:1x11x11x16,image_example1/conv2d_2/Conv2D, dtype:int8, shape:24x3x3x16,image_example1/activation/Relu;image_example1/batch_normalization/FusedBatchNormV3;image_example1/conv2d_2/BiasAdd/ReadVariableOp/resource;image_example1/conv2d_2/BiasAdd;image_example1/conv2d_2/Conv2D, dtype:int32, shape:24
	Outputs=image_example1/activation/Relu;image_example1/batch_normalization/FusedBatchNormV3;image_example1/conv2d_2/BiasAdd/ReadVariableOp/resource;image_example1/conv2d_2/BiasAdd;image_example1/conv2d_2/Conv2D1, dtype:int8, shape:1x9x9x24
Layer op4-average_pool_2d:
	Inputs=image_example1/activation/Relu;image_example1/batch_normalization/FusedBatchNormV3;image_example1/conv2d_2/BiasAdd/ReadVariableOp/resource;image_example1/conv2d_2/BiasAdd;image_example1/conv2d_2/Conv2D1, dtype:int8, shape:1x9x9x24
	Outputs=image_example1/average_pooling2d_1/AvgPool, dtype:int8, shape:1x4x4x24
Layer op5-reshape:
	Inputs=image_example1/average_pooling2d_1/AvgPool, dtype:int8, shape:1x4x4x24,image_example1/flatten/Const, dtype:int32, shape:2
	Outputs=image_example1/flatten/Reshape, dtype:int8, shape:1x384
Layer op6-fully_connected:
	Inputs=image_example1/flatten/Reshape, dtype:int8, shape:1x384,image_example1/dense/MatMul, dtype:int8, shape:3x384,image_example1/dense/BiasAdd/ReadVariableOp/resource, dtype:int32, shape:3
	Outputs=image_example1/dense/MatMul;image_example1/dense/BiasAdd, dtype:int8, shape:1x3
Layer op7-softmax:
	Inputs=image_example1/dense/MatMul;image_example1/dense/BiasAdd, dtype:int8, shape:1x3
	Outputs=Identity_int8, dtype:int8, shape:1x3
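
The layers property can also be used for simple model introspection. For example, the following sketch tallies how many times each opcode appears by parsing the opN-<opcode> layer names shown above:

# Count how many layers of each opcode the model contains
from collections import Counter

op_counts = Counter(layer.name.split('-', 1)[-1] for layer in tflite_model.layers)
for opcode, count in op_counts.items():
    print(f'{opcode}: {count}')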

Example 3: Add metadata to .tflite

This example loads a .tflite model file and adds “metadata” to it.

# Import the TfliteModel class
from mltk.core import TfliteModel 

# Load the .tflite
tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)

# Add metadata with the key: "my_metadata"
tflite_model.add_metadata('my_metadata', b'This is some arbitrary metadata that will be embedded into the .tflite')

# At this point, the metadata is only cached in RAM
# Save the model back to the .tflite file so that the added metadata persists
tflite_model.save()

# At a later time, the .tflite can be loaded from the .tflite file again
tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)

# Retrieve the metadata from the model
my_metadata = tflite_model.get_metadata('my_metadata')
print(f'my_metadata={my_metadata}')
my_metadata=b'This is some arbitrary metadata that will be embedded into the .tflite'
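
Since metadata is stored as raw bytes, one convenient pattern is to serialize structured data as JSON before embedding it. The sketch below is just one way to do this; the 'my_settings' key and its contents are arbitrary examples, not something the MLTK requires.

# Embed a small JSON document as metadata
import json

settings = {'threshold': 0.75, 'author': 'example'}
tflite_model.add_metadata('my_settings', json.dumps(settings).encode('utf-8'))
tflite_model.save()

# Later, read the metadata back and decode it
loaded_settings = json.loads(tflite_model.get_metadata('my_settings').decode('utf-8'))
print(f'loaded_settings={loaded_settings}')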

Example 4: Add model parameters to the .tflite

This example loads a .tflite model file and adds model parameters to it.

# Import the TfliteModel class
from mltk.core import TfliteModel, TfliteModelParameters 

# Load the .tflite
tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)

# Load the model parameters
tflite_model_params = TfliteModelParameters.load_from_tflite_model(tflite_model)

# Add some parameters
tflite_model_params['my_bool'] = True 
tflite_model_params['my_int'] = 42
tflite_model_params['my_float'] = 3.14
tflite_model_params['my_str'] = 'This is a string parameter'
tflite_model_params['my_list_int'] = [1, 2, 3]
tflite_model_params['my_list_float'] = [1.1, 2.2, 3.3]
tflite_model_params['my_list_str'] = ['This', 'is', 'a', 'string', 'list']
tflite_model_params['my_bytes'] = bytearray([1, 2, 3, 4])

# Add the new model parameters to the tflite model
tflite_model_params.add_to_tflite_model(tflite_model)

# At this point, the model parameters are only cached in RAM
# Save the model back to the .tflite file so that the added parameters persist
tflite_model.save()

# At a later time, the .tflite can be loaded from the .tflite file again
tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)

# Load the model parameters
tflite_model_params = TfliteModelParameters.load_from_tflite_model(tflite_model)

# At this point, tflite_model_params is just a Python dictionary
# NOTE: .tflite models generated by the MLTK add additional model parameters by default
#       See: https://siliconlabs.github.io/mltk/docs/guides/model_parameters.html
for key, value in tflite_model_params.items():
    print(f'{key} = {value}')
name = image_example1
version = 1
classes = ['rock', 'paper', 'scissor']
hash = e8463b1e31855c5e6319493226b8b582
date = 2021-08-18T16:51:34.028Z
samplewise_norm.rescale = 0
samplewise_norm.mean = True
samplewise_norm.std = True
my_bool = True
my_int = 42
my_float = 3.140000104904175
my_str = This is a string parameter
my_list_int = [1, 2, 3]
my_list_float = [1.100000023841858, 2.200000047683716, 3.299999952316284]
my_list_str = ['This', 'is', 'a', 'string', 'list']
my_bytes = b'\x01\x02\x03\x04'
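
Because tflite_model_params iterates like a Python dictionary, individual values can also be read with normal dictionary access. A short sketch, assuming the dict-style interface shown above:

# Read individual parameters by key
class_names = tflite_model_params['classes']
my_int = tflite_model_params['my_int']
print(f'classes={class_names}, my_int={my_int}')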

Example 5: Run inference

This package also allows for running inference with the .tflite model.

import tensorflow as tf
import numpy as np
# Import the TfliteModel class
from mltk.core import TfliteModel 
# By default, this example uses the image_example1.tflite model
# which was trained using the Rock, Paper, Scissors dataset
# You must change this to match your model's dataset
from mltk.datasets.image import rock_paper_scissors_v2

# Load the .tflite
tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)


dataset_dir = rock_paper_scissors_v2.load_data()

def _load_sample(class_name):
    base_dir = f'{dataset_dir}/{class_name}'
    # Retrieve the first sample filename for the given class
    filename = os.listdir(base_dir)[0]
    image_path = f'{base_dir}/{filename}'
    # Load the sample image
    img = tf.keras.preprocessing.image.load_img(image_path, color_mode='grayscale')
    # Convert the image to a numpy array
    img_array = tf.keras.preprocessing.image.img_to_array(img, dtype='uint8')

    # Normalize the image array
    # NOTE: This is how the image_example1.tflite model was trained
    #       This must be modified as necessary for your .tflite
    norm_img = (img_array - np.mean(img_array)) / np.std(img_array)

    # Ensure the data type is float32
    norm_img = norm_img.astype('float32')
    return norm_img

# Load a sample for each class type 
rock_sample = _load_sample('rock')
paper_sample = _load_sample('paper')
scissors_sample = _load_sample('scissor')

# Run inference on the "rock" sample
prep = tflite_model.predict(rock_sample)
print(f'Rock prediction: {prep}')

# Run inference on the "paper" sample
prep = tflite_model.predict(paper_sample)
print(f'Paper prediction: {prep}')

# Run inference on the "scissor" sample
prep = tflite_model.predict(scissors_sample)
print(f'Scissors prediction: {prep}')
Rock prediction: [ 127 -128 -128]
Paper prediction: [-128  122 -122]
Scissors prediction: [-128 -127  127]
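
The raw predictions above are quantized int8 scores, one per class. A common follow-up step is to take the argmax and map it back to the class names stored in the model parameters (see Example 4). A sketch, assuming the example model's 'classes' parameter:

# Map the quantized predictions back to human-readable class names
from mltk.core import TfliteModelParameters

params = TfliteModelParameters.load_from_tflite_model(tflite_model)
class_names = params['classes']  # ['rock', 'paper', 'scissor'] for the example model

for name, sample in [('rock', rock_sample), ('paper', paper_sample), ('scissors', scissors_sample)]:
    pred = tflite_model.predict(sample)
    print(f'{name} sample -> predicted "{class_names[np.argmax(pred)]}"')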

Example 6: Access calculated layer parameters

Layer parameters calculated by TensorFlow-Lite Micro are also made accessible by the tflite_model package:

from mltk.core import (
    TfliteModel,
    TfliteConv2dLayer,
    TfliteFullyConnectedLayer,
    TflitePooling2dLayer
)

# Load the .tflite
tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)

# Iterate through each of the model layers
# and access the calculated parameters for the supported layers
for layer in tflite_model.layers:
    if isinstance(layer, TfliteConv2dLayer):
        conv2d_params = layer.params
        per_channel_output_multiplier = conv2d_params.per_channel_output_multiplier
        per_channel_output_shift = conv2d_params.per_channel_output_shift
        output_offset = conv2d_params.output_offset
    elif isinstance(layer, TfliteFullyConnectedLayer):
        fully_connected_params = layer.params
        quantized_activation_min = fully_connected_params.quantized_activation_min
        quantized_activation_max = fully_connected_params.quantized_activation_max
    elif isinstance(layer, TflitePooling2dLayer):
        pool_params = layer.params
        padding_width = pool_params.padding.width
        padding_height = pool_params.padding.height
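
These calculated parameters can also be printed for a quick sanity check. For example, the sketch below prints the quantized activation range of each fully connected layer, using only the fields accessed above:

# Print the quantized activation range of each fully connected layer
for layer in tflite_model.layers:
    if isinstance(layer, TfliteFullyConnectedLayer):
        print(f'{layer.name}: quantized activation range = '
              f'[{layer.params.quantized_activation_min}, {layer.params.quantized_activation_max}]')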

Example 7: Update model weights

The tflite_model package also allows for updating the model.

The model tensors and various parameters may be modified; calling regenerate_flatbuffer then updates the underlying .tflite flatbuffer.

import numpy as np
from mltk.core import (TfliteModel, TfliteConv2dLayer)

# Load the .tflite
tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)

# Retrieve the first layer of the model 
# (which happens to be a conv2d layer if you're using the example .tflite model)
# NOTE: We use the :TfliteConv2dLayer type hint
#       so the IDE can access the properties specific to the conv2d layer
conv2d_layer:TfliteConv2dLayer = tflite_model.layers[0]

# Get the filters tensor
filters = conv2d_layer.filters_tensor

# Set the filters to zero
filters.data = np.zeros(filters.shape, dtype=filters.dtype)

# Update the quantization as well
quantization = filters.quantization
n_channels = conv2d_layer.output_data.shape[-1]
quantization.zeropoint = np.zeros((n_channels,), dtype=np.int32)
quantization.scale = np.zeros((n_channels,), dtype=np.float32)

# Update the underlying flatbuffer with our changes
tflite_model.regenerate_flatbuffer()

# Save the updated model
updated_tflite_path = os.path.normpath(f'{tempfile.gettempdir()}/modified_image_example1.tflite')
tflite_model.save(updated_tflite_path)


# Re-load the saved model and verify that the changes were actually saved
updated_tflite_model = TfliteModel.load_flatbuffer_file(updated_tflite_path)
updated_conv2d_layer:TfliteConv2dLayer = updated_tflite_model.layers[0]
updated_filters = updated_conv2d_layer.filters_tensor
updated_quantization = updated_filters.quantization

assert np.allclose(updated_filters.data, filters.data)
assert np.allclose(updated_quantization.zeropoint, filters.quantization.zeropoint)
assert np.allclose(updated_quantization.scale, filters.quantization.scale)
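
As a final check, the modified model can still be run through predict() to confirm that it loads and executes after the flatbuffer was regenerated. A sketch, assuming the 96x96x1 input shape shown in the Example 1 summary; with zeroed filters the output values are not meaningful, only that inference completes:

# Run a dummy inference on the modified model
dummy_input = np.zeros((96, 96, 1), dtype=np.float32)
output = updated_tflite_model.predict(dummy_input)
print(f'Modified model output: {output}')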