Keyword Spotting with Transfer Learning
This tutorial describes how to use the weights from the pre-trained model keyword_spotting_mobilenetv2 as a starting point for training a new model, keyword_spotting_with_transfer_learning.py, to detect the keywords "one", "two", "three", "four", "five", "six", "seven", "eight", "nine".
This process is known as transfer learning, and it can greatly reduce training time because the new model can leverage the knowledge learned by another model.
See the Keras Documentation for more details about how to use the Keras API to enable transfer learning.
Running this Tutorial
Before getting started, it is recommended to review the following documentation:
MLTK Overview - An overview of the core concepts used by the MLTK
Keyword Spotting Overview - An overview of how keyword spotting works
Keyword Spotting Tutorial - Detailed tutorial describing how to create a Keyword Spotting model using the MLTK
Before continuing, it helps to understand the basic idea of how a Convolutional Neural Network (CNN) works.
The following diagram illustrates a typical CNN model:
(By Aphex34 - Own work, CC BY-SA 4.0)
There are two basic parts:
Feature Extraction - Convolutional filters extract meaningful information (e.g. lines, shapes, textures, colors) from the input image. Multiple layers of filters are used to convert the raw image into a more abstract and compressed representation.
Classifier - The output of the final convolutional layer is fed into a Fully Connected layer. The Fully Connected layer converts the abstract, compressed representation generated by the convolutional layers into a probability distribution over the possible “classes” or object types (e.g. dog, cat, fish) to which the input image may belong.
(Of course, it can get way more complex than this, but that is basically what is going on).
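To make the two parts concrete, here is a toy sketch in plain Python, assuming a 1-D input signal instead of a 2-D image and hand-picked (not trained) weights. It is only an illustration of the idea, not how the MLTK or Keras build a CNN: each filter is convolved over the input, passed through ReLU, and reduced to one feature by global max pooling; a fully connected layer plus softmax then turns those features into class probabilities. The class names "rising" and "falling" are hypothetical.

```python
import math


def conv1d(signal, kernel):
    """'Valid' 1-D convolution (cross-correlation, as most DL libraries compute it)."""
    k = len(kernel)
    return [sum(signal[i + j] * kernel[j] for j in range(k))
            for i in range(len(signal) - k + 1)]


def relu(values):
    return [max(0.0, v) for v in values]


def extract_features(signal, kernels):
    """Feature extraction: one feature per filter (conv -> ReLU -> global max pool)."""
    return [max(relu(conv1d(signal, kernel))) for kernel in kernels]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify(features, weights, biases):
    """Classifier: a fully connected layer followed by softmax."""
    logits = [sum(w * f for w, f in zip(row, features)) + b
              for row, b in zip(weights, biases)]
    return softmax(logits)


# Hand-picked filters: a "rising edge" detector and a "falling edge" detector
kernels = [[-1.0, 1.0], [1.0, -1.0]]
# Hand-picked classifier weights for two hypothetical classes: "rising", "falling"
weights = [[1.0, -1.0], [-1.0, 1.0]]
biases = [0.0, 0.0]

signal = [0.0, 0.0, 1.0, 1.0]  # a step up
features = extract_features(signal, kernels)
probs = classify(features, weights, biases)
print(features, probs)
```

During real training, the values in `kernels`, `weights` and `biases` would start out random and be learned from data, which is the subject of the next paragraphs.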
So how does the CNN model learn to extract the features and generate the probability distribution?
Each “layer” of the model has “trainable” parameters such as weights and filters. Initially, these weights and filters are set to random values. During model training, the weights and filters are adjusted so that the model’s predictions are as accurate as possible on the given dataset. At the end of training, the weights and filters that gave the best predictions are saved to a file and ultimately programmed onto the embedded device.
Since the weights and filters are initially random, it can take a long time to train a model from scratch.
The idea behind transfer learning is to train a new model starting with the best weights and filters of a previously trained model. If the datasets are similar, then the new model does not need to relearn how to extract the features (e.g. lines, shapes, textures, colors).
It just needs to learn how to map the abstract, compressed representation of the convolutional layers to the new “classes” or objects (e.g. car, truck, bike).
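The following toy sketch illustrates why a good starting point shortens training. It is an analogy under simplifying assumptions, not the MLTK training loop: a single-weight linear model `y = w * x` is fitted with gradient descent, once from an arbitrary starting weight (a stand-in for random initialization) and once from a weight "learned" on a similar task (`y = 2.8 * x`). The `train` helper and all numbers are hypothetical.

```python
def train(w, data, lr=0.05, tol=1e-4, max_steps=10_000):
    """Gradient descent on mean squared error for the model y = w * x.

    Returns the final weight and the number of update steps taken.
    """
    for step in range(max_steps):
        loss = sum((w * x - y) ** 2 for x, y in data) / len(data)
        if loss < tol:
            return w, step
        # dL/dw for L = mean((w*x - y)^2)
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w, max_steps


# The "new" dataset follows y = 3x
data = [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0)]

# Stand-in for a randomly initialized weight
w_scratch, steps_scratch = train(-5.0, data)
# Weight transferred from a model trained on a similar task (y = 2.8x)
w_transfer, steps_transfer = train(2.8, data)

print(f"from scratch: {steps_scratch} steps, transferred: {steps_transfer} steps")
```

Both runs reach the same solution, but the transferred start needs fewer updates because it begins much closer to it. A real network has millions of weights, so the savings from reusing the feature-extraction layers are correspondingly larger.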
Base Model Overview
The pre-trained base model, whose weights are transferred to the new model, may be found on GitHub: keyword_spotting_mobilenetv2.
This model is built using MobileNetV2, an industry-standard classification model developed by Google. MobileNetV2 is a useful model because it is generic enough to be applied to most classification tasks, yet it still runs efficiently on embedded devices.
This model is designed to detect the following keywords:
Test model using PC microphone
If you have a microphone connected to your computer, you can optionally run the following command to see the model detect the keywords:
```
# Test the keyword_spotting_mobilenetv2 model using the PC's microphone
# We simulate a latency of 130ms, as that's the approximate latency
# that would be seen on the development board
!mltk classify_audio keyword_spotting_mobilenetv2 --latency 130
```
Test model using development board
If you have a supported development board (currently just the BRD2601), you can optionally run the following command to see the model detect keywords using the development board’s microphone:
```
# Test the keyword_spotting_mobilenetv2 model using the development board's microphone
# The red LED will turn on when a keyword is detected
# The green LED will turn on when there's audio activity
# NOTE: Your mouth should be ~2 inches from the microphone
!mltk classify_audio keyword_spotting_mobilenetv2 --device --accelerator mvp
```
Configure Model Specification with Transfer Learning
The completed model specification used by this tutorial may be found on GitHub: keyword_spotting_with_transfer_learning.py
This model specification is very similar to the base model specification, keyword_spotting_mobilenetv2, with the following key differences:
Update the model’s description to help keep track of it in the field.
```python
my_model.description = 'Keyword spotting classifier using transfer learning to detect: "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"'
```
Set epochs to a small value
Since this model leverages the knowledge of the base model, it does not need many epochs to tune the model parameters (e.g. filters and weights).
```python
# Since we're using transfer learning, we should only need
# a small number of epochs to get a well-trained model
my_model.epochs = 10
```
Use LearningRateScheduler with a small initial value
To have better control of the learning rate, we use a LearningRateScheduler. We also set the initial learning rate to a small value.
```python
def lr_schedule(epoch):
    # When using transfer learning, the initial learning rate should start at a fairly small value
    initial_learning_rate = 0.0005
    decay_per_epoch = 0.95
    lrate = initial_learning_rate * (decay_per_epoch ** epoch)
    return lrate

my_model.lr_schedule = dict(
    schedule=lr_schedule,
    verbose=1
)
```
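To see what this schedule actually does over the 10 training epochs configured above, the sketch below evaluates the same `lr_schedule` function for epochs 0 through 9 (Keras counts epochs from 0). It runs standalone and only reproduces the tutorial's function; nothing here is MLTK-specific.

```python
def lr_schedule(epoch):
    # Same schedule as the model specification: exponential decay from a small start
    initial_learning_rate = 0.0005
    decay_per_epoch = 0.95
    return initial_learning_rate * (decay_per_epoch ** epoch)


epochs = 10
rates = [lr_schedule(e) for e in range(epochs)]
for epoch, lr in enumerate(rates):
    print(f"epoch {epoch}: learning rate = {lr:.6f}")
```

Each epoch multiplies the rate by 0.95, so by the last epoch it has only dropped to roughly 63% of its starting value. The small starting value is what protects the transferred weights from being overwritten by large early updates.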
Update the keywords to detect
Update the keywords we want this model to detect.
While the keywords “one” through “nine” are supported, using all of them greatly increases the training time, as each keyword adds ~3k samples to the training dataset.
If you want to quickly train a model, it is recommended to only use a few keywords.
```python
# Using all 9 keywords may take a LONG TIME to train
# my_model.classes = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', '_unknown_', '_silence_']

# Using 3 keywords should take a much SHORTER TIME to train
my_model.classes = ['one', 'two', 'three', '_unknown_', '_silence_']
```
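A quick back-of-the-envelope estimate shows how the class list drives dataset size. This is a rough sketch that only uses the ~3k samples per keyword quoted above; the helper `estimate_keyword_samples` is hypothetical, and it deliberately ignores the dynamically generated `_unknown_` and `_silence_` samples.

```python
SAMPLES_PER_KEYWORD = 3_000  # approximate per-keyword figure quoted in this tutorial


def estimate_keyword_samples(classes):
    """Hypothetical helper: rough count of keyword samples for a class list.

    The generated _unknown_ and _silence_ samples are not counted.
    """
    keywords = [c for c in classes if c not in ('_unknown_', '_silence_')]
    return len(keywords) * SAMPLES_PER_KEYWORD


all_nine = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine',
            '_unknown_', '_silence_']
three = ['one', 'two', 'three', '_unknown_', '_silence_']

print(estimate_keyword_samples(all_nine), estimate_keyword_samples(three))
```

Under this estimate, the three-keyword configuration trains on about a third as many keyword samples per epoch as the nine-keyword one.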
Reduce the “unknown_class_percentage”
To help reduce training time, we reduce the size of the dataset by decreasing the number of “unknown” samples dynamically generated by the ParallelAudioDataGenerator.
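The effect on dataset size can be sketched as follows. This assumes, for illustration only, that the number of generated "unknown" samples is a multiple of the smallest keyword class; the exact meaning and scale of `unknown_class_percentage` should be checked in the ParallelAudioDataGenerator documentation. The function `estimate_dataset_size` and its `unknown_ratio` parameter are hypothetical.

```python
def estimate_dataset_size(keyword_class_sizes, unknown_ratio):
    """Hypothetical estimate of samples per epoch.

    ASSUMPTION: generated "unknown" samples = unknown_ratio * smallest keyword class.
    """
    n_unknown = int(min(keyword_class_sizes) * unknown_ratio)
    return sum(keyword_class_sizes) + n_unknown


sizes = [3000, 3000, 3000]  # ~3k samples each for "one", "two", "three"

full = estimate_dataset_size(sizes, unknown_ratio=1.0)
reduced = estimate_dataset_size(sizes, unknown_ratio=0.5)
print(full, reduced)
```

Halving the unknown ratio in this sketch removes 1,500 samples from every epoch, which is the kind of saving the setting is meant to provide.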
Load the weights from the base model
The most important change is actually loading the weights from the base model into this model.
To load the weights:
Instantiate a MobileNetV2() instance
Load the keyword_spotting_mobilenetv2 model object using load_mltk_model()
Retrieve the path to the model’s .h5 file in the model’s archive. The .h5 model file contains the trained weights.
Load the trained weights into the MobileNetV2() instance
Compile and return the model instance with the weights from the base model
NOTE: Keras also describes ways of “freezing” layers of the model so that their weights do not change during training. See the Keras transfer learning guide (https://keras.io/guides/transfer_learning) for more details.
```python
from mltk.core import load_mltk_model
from mltk.models.shared import MobileNetV2


# NOTE: MyModel is the model class defined earlier in the model specification
def my_model_builder(model: MyModel):
    # Create an instance of the MobileNetV2
    # NOTE: This should have similar parameters to the keyword_spotting_mobilenetv2 model
    # since we're transferring weights from it
    keras_model = MobileNetV2(
        input_shape=model.input_shape,
        classes=model.n_classes,
        alpha=0.15,
        last_block_filters=384,
        include_top=True,
        weights=None
    )

    # Load the "keyword_spotting_mobilenetv2" model
    # We want to transfer its weights to this model
    # In this way, this new model can start with the knowledge
    # that the keyword_spotting_mobilenetv2 model already has
    # NOTE: This step is not needed if you already have a .h5 file
    base_mltk_model = load_mltk_model('keyword_spotting_mobilenetv2')

    # Get the file path to the .h5 file found in the keyword_spotting_mobilenetv2 model archive
    # The .h5 file contains the trained weights we want to transfer to this model
    base_model_h5_path = base_mltk_model.h5_archive_path

    # Load the keyword_spotting_mobilenetv2 weights into this model
    keras_model.load_weights(
        base_model_h5_path,
        by_name=True,
        skip_mismatch=True  # Skip mismatched layers in case the number of classes is different
    )

    # NOTE: https://keras.io/guides/transfer_learning recommends
    # "freezing" layers of the base model during training. However, in this instance,
    # it was found that making all layers trainable gave better performance.
    keras_model.compile(
        loss=model.loss,
        optimizer=model.optimizer,
        metrics=model.metrics
    )
    return keras_model
```
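The `by_name=True, skip_mismatch=True` arguments are what allow weights to move between models whose final classifier differs in size. The plain-Python sketch below mimics that matching logic under simplified assumptions: each "layer" is a dict with a single shape and a flat list of values, and the layer names `features` and `classifier` are hypothetical. It is not Keras's implementation, which matches every weight tensor of each named layer in the .h5 file.

```python
def transfer_weights(target, source):
    """Copy weights from `source` into `target` for layers whose name and
    shape both match; return the names of target layers that were skipped."""
    skipped = []
    for name, layer in target.items():
        src = source.get(name)
        if src is None:
            # No layer with this name in the base model: keep the fresh weights
            skipped.append(name)
            continue
        if src['shape'] != layer['shape']:
            # Shape mismatch, e.g. a classifier with a different number of classes
            skipped.append(name)
            continue
        layer['values'] = list(src['values'])
    return skipped


# Base model: trained feature layer plus a 3-class classifier (hypothetical sizes)
base = {
    'features': {'shape': (4,), 'values': [0.1, 0.2, 0.3, 0.4]},
    'classifier': {'shape': (4, 3), 'values': [0.5] * 12},
}
# New model: same feature layer, 5-class classifier, freshly initialized (zeros here)
new = {
    'features': {'shape': (4,), 'values': [0.0] * 4},
    'classifier': {'shape': (4, 5), 'values': [0.0] * 20},
}

skipped = transfer_weights(new, base)
print("skipped:", skipped)
```

The feature-extraction weights are reused while the mismatched classifier keeps its fresh initialization, which is exactly the split described earlier: the new model only has to learn how to map the transferred features to the new classes.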
Train the Model
With the model specification complete, invoke the model training. Since we’re only training for 10 epochs, this should complete relatively quickly.
```
!mltk train keyword_spotting_with_transfer_learning
```
Train in the cloud
Alternatively, you can vastly reduce the model training time by training this model in the “cloud”.
See the tutorial: Cloud Training with vast.ai for more details.
Test the model
With the model trained, we can see how well it runs on the development board by issuing the command:
```
# Test the keyword_spotting_with_transfer_learning model using the development board's microphone
# The red LED will turn on when a keyword is detected
# The green LED will turn on when there's audio activity
# NOTE: Your mouth should be ~2 inches from the board's microphone
!mltk classify_audio keyword_spotting_with_transfer_learning --device --accelerator mvp
```