keyword_spotting_pacman

This model is a CNN classifier to detect the keywords:

  • left

  • right

  • up

  • down

  • stop

  • go

It is specially trained to handle the background noise generated by the video game “Pac-Man”.

This model specification script is designed to work with the Keyword Spotting Pac-Man tutorial.

Dataset

This uses the mltk.datasets.audio.speech_commands.speech_commands_v2 dataset provided by Google, plus additional background noise recorded while playing the video game “Pac-Man”.

Preprocessing

This uses the mltk.core.preprocess.audio.parallel_generator.ParallelAudioDataGenerator with the mltk.core.preprocess.audio.audio_feature_generator.AudioFeatureGenerator settings:

  • sample_rate: 16kHz

  • sample_length: 700ms

  • window size: 20ms

  • window step: 10ms

  • n_channels: 70
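These settings determine the shape of the spectrogram fed to the model. The sketch below is a rough check, not the frontend's actual code: it assumes one frame is emitted for every full window that fits in the sample, and that the FFT length is the next power of two at or above the window size in samples. The helper names are hypothetical. Under those assumptions the numbers agree with the 69x70x1 model input and the fe.fft_length of 512 reported in the model summary.

```python
def spectrogram_shape(sample_length_ms, window_size_ms, window_step_ms, n_channels):
    """Return (n_frames, n_channels), assuming one frame per full window."""
    n_frames = (sample_length_ms - window_size_ms) // window_step_ms + 1
    return n_frames, n_channels


def assumed_fft_length(sample_rate_hz, window_size_ms):
    """Next power of two >= window size in samples (an assumption)."""
    window_samples = sample_rate_hz * window_size_ms // 1000
    return 1 << (window_samples - 1).bit_length()


shape = spectrogram_shape(700, 20, 10, 70)
fft_length = assumed_fft_length(16000, 20)
print(shape)       # (69, 70)
print(fft_length)  # 512
```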

Commands

# Do a "dry run" test training of the model
mltk train keyword_spotting_pacman-test

# Train the model
mltk train keyword_spotting_pacman

# Evaluate the trained .tflite model
mltk evaluate keyword_spotting_pacman --tflite

# Profile the model in the MVP hardware accelerator simulator
mltk profile keyword_spotting_pacman --accelerator MVP

# Profile the model on a physical development board
mltk profile keyword_spotting_pacman --accelerator MVP --device

# Run the model in the audio classifier on the local PC
mltk classify_audio keyword_spotting_pacman --verbose

# Run the model in the audio classifier on the physical device
mltk classify_audio keyword_spotting_pacman --device --verbose

Model Summary

mltk summarize keyword_spotting_pacman --tflite

+-------+-----------------+-----------------+-----------------+-----------------------------------------------------+
| Index | OpCode          | Input(s)        | Output(s)       | Config                                              |
+-------+-----------------+-----------------+-----------------+-----------------------------------------------------+
| 0     | conv_2d         | 69x70x1 (int8)  | 69x70x7 (int8)  | Padding:same stride:1x1 activation:relu             |
|       |                 | 3x3x1 (int8)    |                 |                                                     |
|       |                 | 7 (int32)       |                 |                                                     |
| 1     | max_pool_2d     | 69x70x7 (int8)  | 34x35x7 (int8)  | Padding:valid stride:2x2 filter:2x2 activation:none |
| 2     | conv_2d         | 34x35x7 (int8)  | 34x35x14 (int8) | Padding:same stride:1x1 activation:relu             |
|       |                 | 3x3x7 (int8)    |                 |                                                     |
|       |                 | 14 (int32)      |                 |                                                     |
| 3     | max_pool_2d     | 34x35x14 (int8) | 17x17x14 (int8) | Padding:valid stride:2x2 filter:2x2 activation:none |
| 4     | conv_2d         | 17x17x14 (int8) | 17x17x28 (int8) | Padding:same stride:1x1 activation:relu             |
|       |                 | 3x3x14 (int8)   |                 |                                                     |
|       |                 | 28 (int32)      |                 |                                                     |
| 5     | max_pool_2d     | 17x17x28 (int8) | 8x8x28 (int8)   | Padding:valid stride:2x2 filter:2x2 activation:none |
| 6     | conv_2d         | 8x8x28 (int8)   | 8x8x28 (int8)   | Padding:same stride:1x1 activation:relu             |
|       |                 | 3x3x28 (int8)   |                 |                                                     |
|       |                 | 28 (int32)      |                 |                                                     |
| 7     | max_pool_2d     | 8x8x28 (int8)   | 4x4x28 (int8)   | Padding:valid stride:2x2 filter:2x2 activation:none |
| 8     | conv_2d         | 4x4x28 (int8)   | 4x4x28 (int8)   | Padding:same stride:1x1 activation:relu             |
|       |                 | 3x3x28 (int8)   |                 |                                                     |
|       |                 | 28 (int32)      |                 |                                                     |
| 9     | max_pool_2d     | 4x4x28 (int8)   | 2x2x28 (int8)   | Padding:valid stride:2x2 filter:2x2 activation:none |
| 10    | reshape         | 2x2x28 (int8)   | 112 (int8)      | Type=none                                           |
|       |                 | 2 (int32)       |                 |                                                     |
| 11    | fully_connected | 112 (int8)      | 7 (int8)        | Activation:none                                     |
|       |                 | 112 (int8)      |                 |                                                     |
|       |                 | 7 (int32)       |                 |                                                     |
| 12    | softmax         | 7 (int8)        | 7 (int8)        | Type=softmaxoptions                                 |
+-------+-----------------+-----------------+-----------------+-----------------------------------------------------+
Total MACs: 2.939 M
Total OPs: 5.997 M
Name: keyword_spotting_pacman
Version: 2
Description: Keyword spotting classifier to detect: left, right, up, down, stop, go
Classes: left, right, up, down, stop, go, _silence_
hash: ba1f45639f61f277834e4c8ee71ae040
date: 2022-06-24T15:44:22.413Z
runtime_memory_size: 45475
average_window_duration_ms: 150
detection_threshold: 205
detection_threshold_list: [240, 240, 200, 215, 230, 200, 255]
suppression_ms: 1
minimum_count: 2
volume_gain: 0.0
latency_ms: 10
verbose_model_output_logs: False
samplewise_norm.rescale: 0.0
samplewise_norm.mean_and_std: False
fe.sample_rate_hz: 16000
fe.sample_length_ms: 700
fe.window_size_ms: 20
fe.window_step_ms: 10
fe.filterbank_n_channels: 70
fe.filterbank_upper_band_limit: 8000.0
fe.filterbank_lower_band_limit: 150.0
fe.noise_reduction_enable: False
fe.noise_reduction_smoothing_bits: 10
fe.noise_reduction_even_smoothing: 0.004000000189989805
fe.noise_reduction_odd_smoothing: 0.004000000189989805
fe.noise_reduction_min_signal_remaining: 0.05000000074505806
fe.pcan_enable: False
fe.pcan_strength: 0.949999988079071
fe.pcan_offset: 80.0
fe.pcan_gain_bits: 21
fe.log_scale_enable: True
fe.log_scale_shift: 6
fe.activity_detection_enable: False
fe.activity_detection_alpha_a: 0.5
fe.activity_detection_alpha_b: 0.800000011920929
fe.activity_detection_arm_threshold: 0.75
fe.activity_detection_trip_threshold: 0.800000011920929
fe.dc_notch_filter_enable: True
fe.dc_notch_filter_coefficient: 0.949999988079071
fe.quantize_dynamic_scale_enable: True
fe.quantize_dynamic_scale_range_db: 40.0
fe.fft_length: 512
.tflite file size: 33.9kB
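The layer shapes and the "Total MACs" figure in the summary can be reproduced by hand. The following sketch walks the five conv/pool stages and the final fully connected layer. It assumes one multiply-accumulate per kernel weight per output element for each 'same'-padded, stride-1 convolution, and floor division for each 2x2 'valid' max pool; the helper name is hypothetical.

```python
def conv_macs(height, width, c_in, c_out, kernel=3):
    """MACs for a stride-1, 'same'-padded conv: one dot product per output element."""
    return height * width * kernel * kernel * c_in * c_out


height, width, channels = 69, 70, 1  # model input: 69x70x1
total_macs = 0

for c_out in (7, 14, 28, 28, 28):  # filter counts of the five conv_2d layers
    total_macs += conv_macs(height, width, channels, c_out)
    # 2x2 max pool with stride 2 and 'valid' padding halves each dimension (floor)
    height, width, channels = height // 2, width // 2, c_out

flattened = height * width * channels  # reshape: 2x2x28 -> 112
total_macs += flattened * 7            # fully_connected: 112 inputs, 7 classes

print(flattened)   # 112
print(total_macs)  # 2938726, i.e. about 2.939 M
```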

Model Profiling Report

# Profile on physical EFR32xG24 using MVP accelerator
mltk profile keyword_spotting_pacman --device --accelerator MVP

 Profiling Summary
 Name: keyword_spotting_pacman
 Accelerator: MVP
 Input Shape: 1x69x70x1
 Input Data Type: int8
 Output Shape: 1x7
 Output Data Type: int8
 Flash, Model File Size (bytes): 33.9k
 RAM, Runtime Memory Size (bytes): 51.8k
 Operation Count: 6.1M
 Multiply-Accumulate Count: 2.9M
 Layer Count: 13
 Unsupported Layer Count: 0
 Accelerator Cycle Count: 3.0M
 CPU Cycle Count: 265.0k
 CPU Utilization (%): 8.6
 Clock Rate (hz): 78.0M
 Time (s): 39.4m
 Ops/s: 154.8M
 MACs/s: 74.0M
 Inference/s: 25.3

 Model Layers
 +-------+-----------------+--------+--------+------------+------------+----------+-------------------------+--------------+-----------------------------------------------------+
 | Index | OpCode          | # Ops  | # MACs | Acc Cycles | CPU Cycles | Time (s) | Input Shape             | Output Shape | Options                                             |
 +-------+-----------------+--------+--------+------------+------------+----------+-------------------------+--------------+-----------------------------------------------------+
 | 0     | conv_2d         | 710.0k | 304.3k | 486.7k     | 30.6k      | 6.3m     | 1x69x70x1,7x3x3x1,7     | 1x69x70x7    | Padding:same stride:1x1 activation:relu             |
 | 1     | max_pool_2d     | 33.3k  | 0      | 50.1k      | 14.0k      | 720.0u   | 1x69x70x7               | 1x34x35x7    | Padding:valid stride:2x2 filter:2x2 activation:none |
 | 2     | conv_2d         | 2.1M   | 1.0M   | 1.3M       | 29.6k      | 16.9m    | 1x34x35x7,14x3x3x7,14   | 1x34x35x14   | Padding:same stride:1x1 activation:relu             |
 | 3     | max_pool_2d     | 16.2k  | 0      | 12.2k      | 13.9k      | 240.0u   | 1x34x35x14              | 1x17x17x14   | Padding:valid stride:2x2 filter:2x2 activation:none |
 | 4     | conv_2d         | 2.1M   | 1.0M   | 738.3k     | 30.2k      | 9.4m     | 1x17x17x14,28x3x3x14,28 | 1x17x17x28   | Padding:same stride:1x1 activation:relu             |
 | 5     | max_pool_2d     | 7.2k   | 0      | 5.6k       | 26.4k      | 360.0u   | 1x17x17x28              | 1x8x8x28     | Padding:valid stride:2x2 filter:2x2 activation:none |
 | 6     | conv_2d         | 908.5k | 451.6k | 293.2k     | 30.3k      | 3.8m     | 1x8x8x28,28x3x3x28,28   | 1x8x8x28     | Padding:same stride:1x1 activation:relu             |
 | 7     | max_pool_2d     | 1.8k   | 0      | 1.5k       | 26.3k      | 360.0u   | 1x8x8x28                | 1x4x4x28     | Padding:valid stride:2x2 filter:2x2 activation:none |
 | 8     | conv_2d         | 227.1k | 112.9k | 62.1k      | 27.9k      | 900.0u   | 1x4x4x28,28x3x3x28,28   | 1x4x4x28     | Padding:same stride:1x1 activation:relu             |
 | 9     | max_pool_2d     | 448.0  | 0      | 532.0      | 26.3k      | 330.0u   | 1x4x4x28                | 1x2x2x28     | Padding:valid stride:2x2 filter:2x2 activation:none |
 | 10    | reshape         | 0      | 0      | 0          | 1.0k       | 0        | 1x2x2x28,2              | 1x112        | Type=none                                           |
 | 11    | fully_connected | 1.6k   | 784.0  | 1.2k       | 2.2k       | 60.0u    | 1x112,7x112,7           | 1x7          | Activation:none                                     |
 | 12    | softmax         | 35.0   | 0      | 0          | 6.2k       | 60.0u    | 1x7                     | 1x7          | Type=softmaxoptions                                 |
 +-------+-----------------+--------+--------+------------+------------+----------+-------------------------+--------------+-----------------------------------------------------+
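The summary figures are consistent with the per-layer table. The sketch below sums the per-layer times and derives the inference rate from that total. The CPU utilization formula is an assumption (CPU cycles divided by the total cycles elapsed at the reported clock rate), not something the report states; it comes out at about 8.6 percent, in line with the reported value.

```python
# Per-layer "Time (s)" values from the table above, converted to milliseconds
layer_times_ms = [6.3, 0.72, 16.9, 0.24, 9.4, 0.36, 3.8, 0.36, 0.9, 0.33, 0.0, 0.06, 0.06]

total_ms = sum(layer_times_ms)
inferences_per_s = 1000.0 / total_ms

# Assumption: utilization = CPU cycles / total cycles elapsed at the 78 MHz clock
clock_hz = 78.0e6
cpu_cycles = 265.0e3
elapsed_cycles = clock_hz * total_ms / 1000.0
cpu_utilization_pct = 100.0 * cpu_cycles / elapsed_cycles

print(f"total time: {total_ms:.1f} ms")
print(f"inferences/s: {inferences_per_s:.1f}")
print(f"CPU utilization: {cpu_utilization_pct:.1f} %")
```

The two convolutions at index 2 and 4 account for roughly two thirds of the total time, so they are the first place to look when trying to shorten inference.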

Model Diagram

mltk view keyword_spotting_pacman --tflite

Model Specification

# Import the Tensorflow packages
# required to build the model layout
import os
import librosa
import random
from typing import List, Tuple

import numpy as np
import tensorflow as tf
from tensorflow.keras import regularizers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Dense,
    Activation,
    Flatten,
    BatchNormalization,
    Conv2D,
    MaxPooling2D,
    Dropout
)

# Import the MLTK model object
# and necessary mixins
# Later in this script we configure the various properties
from mltk import MLTK_DIR
from mltk.core import (
    MltkModel,
    TrainMixin,
    AudioDatasetMixin,
    EvaluateClassifierMixin
)

# Import the Google speech_commands dataset package
# This manages downloading and extracting the dataset
from mltk.datasets.audio.speech_commands import speech_commands_v2

# Import the ParallelAudioDataGenerator
# This has two main jobs:
# 1. Process the Google speech_commands dataset and apply random augmentations during training
# 2. Generate a spectrogram using the AudioFeatureGenerator from each augmented audio sample
#    and give the spectrogram to Tensorflow for model training
from mltk.core.preprocess.audio.parallel_generator import ParallelAudioDataGenerator, ParallelProcessParams
# Import the AudioFeatureGeneratorSettings which we'll configure
# and give to the ParallelAudioDataGenerator
from mltk.core.preprocess.audio.audio_feature_generator import AudioFeatureGeneratorSettings
from mltk.utils.archive_downloader import download_verify_extract


# Define a custom model object with the following 'mixins':
# - TrainMixin        - Provides classifier model training operations and settings
# - AudioDatasetMixin - Provides audio data generation operations and settings
# - EvaluateClassifierMixin     - Provides classifier evaluation operations and settings
# @mltk_model # NOTE: This tag is required for this model to be discoverable
class MyModel(
    MltkModel,
    TrainMixin,
    AudioDatasetMixin,
    EvaluateClassifierMixin
):
    pass

# Instantiate our custom model object
# The rest of this script simply configures the properties
# of our custom model object
my_model = MyModel()


#################################################
# General Settings

# For better tracking, the version should be incremented any time a non-trivial change is made
# NOTE: The version is optional and not used directly by the MLTK
my_model.version = 1
# Provide a brief description of what this model does
# This description goes in the "description" field of the .tflite model file
my_model.description = 'Keyword spotting classifier to detect: left, right, up, down, stop, go with Pac-Man video game background noise'

#################################################
# Training Basic Settings

# This specifies the number of times we run the training
# samples through the model to update the model weights.
# Typically, a larger value leads to better accuracy at the expense of training time.
# Set to -1 to use the early_stopping callback and let the scripts
# determine how many epochs to train for (see below).
# Otherwise set this to a specific value (typically 40-200)
my_model.epochs = 80
# Specify how many samples to pass through the model
# before updating the model weights.
# Typical values are 10-64
# NOTE: Larger values require more memory and may not fit on your GPU
my_model.batch_size = 16
# This specifies the algorithm used to update the model weights
# during training. Adam is very common
# See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers
my_model.optimizer = 'adam'
# List of metrics to be evaluated by the model during training and testing
my_model.metrics = ['accuracy']
# The "loss" function used to update the weights
# This is a classification problem with more than two labels so we use categorical_crossentropy
# See https://www.tensorflow.org/api_docs/python/tf/keras/losses
my_model.loss = 'categorical_crossentropy'


#################################################
# Training callback Settings

# Generate checkpoints every time the validation accuracy improves
# See https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint
my_model.checkpoint['monitor'] =  'val_accuracy'

# https://keras.io/api/callbacks/reduce_lr_on_plateau/
# If the training loss doesn't improve after 'patience' epochs
# then decrease the learning rate by 'factor'
my_model.reduce_lr_on_plateau = dict(
  monitor='loss',
  factor = 0.95,
  min_delta=0.001,
  patience = 1,
  verbose=1,
)

# If the training accuracy doesn't improve after 15 epochs then stop training
# https://keras.io/api/callbacks/early_stopping/
my_model.early_stopping = dict(
  monitor = 'accuracy',
  patience = 15,
  verbose=1
)



#################################################
# TF-Lite converter settings

# These are the settings used to quantize the model
# We want all the internal ops as well as
# model input/output to be int8
my_model.tflite_converter['optimizations'] = [tf.lite.Optimize.DEFAULT]
my_model.tflite_converter['supported_ops'] = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# NOTE: A float32 model input/output is also possible
my_model.tflite_converter['inference_input_type'] = np.int8
my_model.tflite_converter['inference_output_type'] = np.int8
# Automatically generate a representative dataset from the validation data
my_model.tflite_converter['representative_dataset'] = 'generate'



#################################################
# Audio Data Provider Settings

# Specify the dataset
# NOTE: This can also be an absolute path to a directory
#       or a Python function
# See: https://siliconlabs.github.io/mltk/docs/python_api/mltk_model/audio_dataset_mixin.html#mltk.core.AudioDatasetMixin.dataset
my_model.dataset = speech_commands_v2
# We're using a 'categorical_crossentropy' loss
# so must also use a `categorical` class mode for the data generation
my_model.class_mode = 'categorical'


# Add the direction keywords plus a _silence_ meta class
my_model.classes = ['left', 'right', 'up', 'down', 'stop', 'go', '_silence_']

# The number of samples differs between classes.
# This ensures each class contributes equally to training the model
my_model.class_weights = 'balanced'


#################################################
# AudioFeatureGenerator Settings
#
# These are the settings used by the AudioFeatureGenerator
# to generate spectrograms from the audio samples
# These settings must be used during model training
# AND by the embedded device at runtime
#
# See https://siliconlabs.github.io/mltk/docs/audio/audio_feature_generator.html
frontend_settings = AudioFeatureGeneratorSettings()

# The sample rate, sample length, window size and window step
# control the width of the generated spectrogram
frontend_settings.sample_rate_hz = 16000  # 16kHz: slightly better performance than a lower sample rate, at the cost of more RAM
frontend_settings.sample_length_ms = 700 # 0.7s sample size
frontend_settings.window_size_ms = 20
frontend_settings.window_step_ms = 10
frontend_settings.filterbank_n_channels = 70 # The number of frequency bins. This controls the height of the spectrogram
frontend_settings.filterbank_upper_band_limit = frontend_settings.sample_rate_hz // 2 # 8kHz
frontend_settings.filterbank_lower_band_limit = 150 # The dev board mic seems to have a lot of noise at lower frequencies
frontend_settings.noise_reduction_enable = False # Disable the noise reduction block
frontend_settings.pcan_enable = False # Disable the PCAN block

frontend_settings.dc_notch_filter_enable = True # Enable the DC notch filter
frontend_settings.dc_notch_filter_coefficient = 0.95

# Enable dynamic quantization
# This quantizes the generated spectrogram from uint16 to int8
frontend_settings.quantize_dynamic_scale_enable = True
frontend_settings.quantize_dynamic_scale_range_db = 40.0


PACMAN_BACKGROUND_NOISE_DIR = download_verify_extract(
    url='https://github.com/SiliconLabs/mltk_assets/raw/master/datasets/recorded_pacman_game_play.7z',
    dest_subdir='datasets/recorded_pacman_game_play',
    file_hash='749F552BC2ABA11E618969D8B0F6E5BDD62AC7A2',
    show_progress=False,
    remove_root_dir=False
)
PACMAN_BACKGROUND_NOISE_PATH = f'{PACMAN_BACKGROUND_NOISE_DIR}/recorded_pacman_game_play.wav'


def get_batches_samples(
    batch_index:int,
    filenames:List[str],
    classes:List[int],
    params:ParallelProcessParams
) -> Tuple[int, Tuple[np.ndarray, np.ndarray]]:
    """This is slightly modified from the standard function that comes with the MLTK:
    https://github.com/siliconlabs/mltk/blob/master/mltk/core/preprocess/audio/parallel_generator/iterator.py#L241

    80% of the time it adds a snippet of the Pac-Man background noise to the sample.

    """

    if 'game' not in params.audio_data_generator.bg_noises:
        PACMAN_BACKGROUND_NOISE, orignal_sr = librosa.load(PACMAN_BACKGROUND_NOISE_PATH, sr=frontend_settings.sample_rate_hz, mono=True, dtype='float32')
        params.audio_data_generator.bg_noises['game'] = PACMAN_BACKGROUND_NOISE


    batch_shape = (len(filenames),) + params.sample_shape
    batch_x = np.zeros(batch_shape, dtype=params.dtype)

    for i, (filename, class_id) in enumerate(zip(filenames, classes)):
        if filename:
            filepath = os.path.join(params.directory, filename)
            x, orignal_sr = librosa.load(filepath, sr=None, mono=True, dtype='float32')

        else:
            orignal_sr = 16000
            x = np.zeros((orignal_sr,), dtype='float32')

        transform_params = params.audio_data_generator.get_random_transform()
        add_game_background_noise = random.uniform(0, 1) < .8  # Add the game background noise 80% of the time
        if add_game_background_noise:
            transform_params['noise_color'] = None
            transform_params['bg_noise'] = 'game'

        # Apply any audio augmentations
        # NOTE: If transform_params =  default_transform
        #       Then the audio sample is simply cropped/padded to fit the expected sample length
        x = params.audio_data_generator.apply_transform(x, orignal_sr, transform_params)

        # After passing through the frontend,
        # x = [height, width] dtype=self.dtype
        x = params.audio_data_generator.apply_frontend(x, dtype=params.dtype)

        # Convert the sample's shape from [height, width]
        # to [height, width, 1]
        batch_x[i] = np.expand_dims(x, axis=-1)


    batch_y = np.zeros((len(batch_x), len(params.class_indices)), dtype=params.dtype)
    for i, class_id in enumerate(classes):
        batch_y[i, class_id] = 1.

    return batch_index, (batch_x, batch_y)


#################################################
# ParallelAudioDataGenerator Settings
#
# Configure the data generator settings
# This specifies how to augment the training samples
# See the command: "mltk view_audio"
# to get a better idea of how these augmentations affect
# the samples
my_model.datagen = ParallelAudioDataGenerator(
    dtype=my_model.tflite_converter['inference_input_type'],
    frontend_settings=frontend_settings,
    cores=0.75, # Adjust this as necessary for your PC setup
    debug=False, # Set this to true to enable debugging of the generator
    max_batches_pending=32,  # Adjust this as necessary for your PC setup (smaller -> less RAM)
    validation_split= 0.15,
    validation_augmentation_enabled=True,
    get_batch_function=get_batches_samples,
    samplewise_center=False,
    samplewise_std_normalization=False,
    rescale=None,
    #unknown_class_percentage=2.0, # Increasing this may help model robustness at the expense of training time
    silence_class_percentage=1,
    offset_range=(0.0,1.0),
    trim_threshold_db=20,
    noise_colors='all',
    loudness_range=(0.6, 1.0),
    speed_range=(0.9,1.1),
    #pitch_range=(-1,1),
    #vtlp_range=(1.0,1.1),
    bg_noise_range=(0.1,0.8),
    bg_noise_dir='_background_noise_' # This is a directory provided by the google speech commands dataset, can also provide an absolute path
)




#################################################
# Model Layout
#
# This defines the actual model layout
# using the Keras API.
# This particular model is a relatively standard
# sequential Convolutional Neural Network (CNN).
#
# It is important to note the usage of the
# "model" argument.
# Rather than hardcode values, the model is
# used to build the model, e.g.:
# Dense(model.n_classes)
#
# This way, the various model properties above can be modified
# without having to re-write this section.
#
def my_model_builder(model: MyModel):
    weight_decay = 1e-4
    regularizer = regularizers.l2(weight_decay)
    input_shape = model.input_shape
    filters = 7

    keras_model = Sequential(name=model.name, layers = [
        Conv2D(filters, (3,3), padding='same', input_shape=input_shape),
        BatchNormalization(),
        Activation('relu'),
        MaxPooling2D(2,2),

        Conv2D(2*filters,(3,3), padding='same'),
        BatchNormalization(),
        Activation('relu'),
        MaxPooling2D(2,2),

        Conv2D(4*filters, (3,3), padding='same'),
        BatchNormalization(),
        Activation('relu'),
        MaxPooling2D(2,2),

        Conv2D(4*filters, (3,3), padding='same'),
        BatchNormalization(),
        Activation('relu'),
        MaxPooling2D(2,2),

        Conv2D(4*filters, (3,3), padding='same'),
        BatchNormalization(),
        Activation('relu'),
        MaxPooling2D(2,2),

        #Dropout(0.5),

        Flatten(),
        Dense(model.n_classes, activation='softmax')
    ])

    keras_model.compile(
        loss=model.loss,
        optimizer=model.optimizer,
        metrics=model.metrics
    )

    return keras_model

my_model.build_model_function = my_model_builder




#################################################
# Audio Classifier Settings
#
# These are additional parameters to include in
# the generated .tflite model file.
# The settings are used by the "classify_audio" command
# or audio_classifier example application.
# NOTE: Corresponding command-line options will override these values.


# Controls the smoothing.
# Drop all inference results that are older than <now> minus window_duration
# Longer durations (in milliseconds) will give a higher confidence that the results are correct, but may miss some commands
my_model.model_parameters['average_window_duration_ms'] = 150

# Define a specific detection threshold for each class
my_model.model_parameters['detection_threshold_list'] = [240, 240, 200, 215, 230, 200, 255]

# Number of milliseconds to wait after a keyword is detected before detecting new keywords
# Since we're using the audio detection block, we want this to be as short as possible
my_model.model_parameters['suppression_ms'] = 1

# The minimum number of inference results to average when calculating the detection value
my_model.model_parameters['minimum_count'] = 2

# Set the volume gain scaler (i.e. amplitude) to apply to the microphone data. If 0 or omitted, no scaler is applied
my_model.model_parameters['volume_gain'] = 0.0

# This is the amount of time in milliseconds between audio processing loops
# Since we're using the audio detection block, we want this to be as short as possible
my_model.model_parameters['latency_ms'] = 10

# Enable verbose inference results
my_model.model_parameters['verbose_model_output_logs'] = False



##########################################################################################
# The following allows for running this model training script directly, e.g.:
# python keyword_spotting_pacman.py
#
# Note that this has the same functionality as:
# mltk train keyword_spotting_pacman
#
if __name__ == '__main__':
    import mltk.core as mltk_core
    from mltk import cli

    # Setup the CLI logger
    cli.get_logger(verbose=False)

    # If this is true then this will do a "dry run" of the model training
    # If this is false, then the model will be fully trained
    test_mode_enabled = True

    # Train the model
    # This does the same as issuing the command: mltk train keyword_spotting_pacman-test --clean
    train_results = mltk_core.train_model(my_model, clean=True, test=test_mode_enabled)
    print(train_results)

    # Evaluate the trained .h5 (i.e. float32) model
    # This does the same as issuing the command: mltk evaluate keyword_spotting_pacman-test
    tflite_eval_results = mltk_core.evaluate_model(my_model, verbose=True, test=test_mode_enabled)
    print(tflite_eval_results)

    # Profile the model in the simulator
    # This does the same as issuing the command: mltk profile keyword_spotting_pacman-test
    profiling_results = mltk_core.profile_model(my_model, test=test_mode_enabled)
    print(profiling_results)
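
The audio classifier settings in the script (average_window_duration_ms, detection_threshold_list, minimum_count) describe a smoothing step applied to the raw inference results. The following is a minimal sketch of that idea, not the MLTK implementation: the class name and its logic are hypothetical, and it assumes model scores are scaled to 0-255 to match the thresholds. It drops results older than the averaging window, waits until at least minimum_count results are available, averages them per class, and reports the top class only if its average reaches that class's threshold.

```python
from collections import deque


class DetectionSmoother:
    """Hypothetical smoothing of per-class scores over a sliding time window."""

    def __init__(self, thresholds, window_ms=150, minimum_count=2):
        self.thresholds = thresholds
        self.window_ms = window_ms
        self.minimum_count = minimum_count
        self.history = deque()  # entries of (timestamp_ms, scores)

    def update(self, timestamp_ms, scores):
        """Add one inference result; return the detected class index or None."""
        self.history.append((timestamp_ms, scores))
        # Drop all results older than <now> minus the averaging window
        while self.history[0][0] < timestamp_ms - self.window_ms:
            self.history.popleft()
        if len(self.history) < self.minimum_count:
            return None
        count = len(self.history)
        averages = [
            sum(result[i] for _, result in self.history) / count
            for i in range(len(scores))
        ]
        best = max(range(len(averages)), key=averages.__getitem__)
        return best if averages[best] >= self.thresholds[best] else None


# Thresholds for: left, right, up, down, stop, go, _silence_
smoother = DetectionSmoother([240, 240, 200, 215, 230, 200, 255])
first = smoother.update(0, [250, 0, 0, 0, 0, 0, 5])     # only one result so far
second = smoother.update(40, [245, 0, 0, 0, 0, 0, 10])  # "left" averages 247.5
print(first, second)  # None 0
```

With a 150 ms window and a new result every few tens of milliseconds, only a handful of results are averaged at a time, which keeps detection responsive at the cost of less smoothing.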