from typing import Union
import logging
import os
import json
import pprint
from tensorflow import keras
import matplotlib.pyplot as plt
from sklearn.utils.class_weight import compute_class_weight
from mltk.utils.python import contains_class_type, prepend_exception_msg
from mltk.utils.path import clean_directory
from mltk.utils import gpu
from .model import (
MltkModel,
MltkModelEvent,
KerasModel,
DatasetMixin,
TrainMixin,
load_mltk_model,
)
from .utils import get_mltk_logger
from .summarize_model import summarize_model
from .quantize_model import quantize_model
from .training_results import TrainingResults
def train_model(
model: Union[MltkModel, str],
    weights: str = None,
    epochs: int = None,
    resume_epoch: int = 0,
    verbose: bool = None,
    clean: bool = False,
    quantize: bool = True,
    create_archive: bool = True,
    show: bool = False,
    test: bool = False,
    post_process: bool = False
) -> TrainingResults:
"""Train a model using Keras and Tensorflow
.. seealso::
* `Model Training Guide <https://siliconlabs.github.io/mltk/docs/guides/model_training.html>`_
* `Model Training API Examples <https://siliconlabs.github.io/mltk/mltk/examples/train_model.html>`_
* `KerasModel.fit() <https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit>`_
Args:
model: :py:class:`mltk.core.MltkModel` instance, name of MLTK model, path to model specification script(.py)
__Note:__ If the model is in "test mode" then the model will train for 1 epoch
weights: Optional file path of model weights to load before training
epochs: Optional, number of epochs to train model. This overrides the mltk_model.epochs attribute
resume_epoch: Optional, resuming training at the given epoch
verbose: Optional, Verbosely print to logger while training
clean: Optional, Clean the log directory before training
quantize: Optional, quantize the model after training successfully completes
create_archive: Optional, create an archive (.mltk.zip) of the training results and generated model files
show: Optional, show the training results diagram
test: Optional, load the model in "test mode" if true.
post_process: This allows for post-processing the training results (e.g. uploading to a cloud) if supported by the given MltkModel
Returns:
The model TrainingResults
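
    Example:

        .. code-block:: python

            from mltk.core import train_model

            # A minimal usage sketch (assumes the MLTK package and its
            # reference models are installed; "image_example1" is one of
            # the bundled reference models)
            results = train_model('image_example1', test=True, clean=True)
            print(results)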
"""
if isinstance(model, MltkModel):
mltk_model = model
if test:
mltk_model.enable_test_mode()
elif isinstance(model, str):
if model.endswith(('.tflite', '.h5', '.zip')):
raise ValueError(
'Must provide name of MLTK model '
            'or path to model specification script(.py)'
)
mltk_model = load_mltk_model(model, test=test)
else:
raise ValueError(
'Must provide MltkModel instance, name of MLTK model, or path to '
'model specification script(.py)'
)
if not isinstance(mltk_model, TrainMixin):
raise ValueError('model argument must be an MltkModel instance that inherits TrainMixin')
if not isinstance(mltk_model, DatasetMixin):
raise ValueError('model argument must be an MltkModel instance that inherits DatasetMixin')
logger = get_mltk_logger()
mltk_model.trigger_event(MltkModelEvent.TRAIN_STARTUP, post_process=post_process, logger=logger)
# Ensure the MltkModel archive is writable before we start training
if create_archive:
mltk_model.check_archive_file_is_writable(throw_exception=True)
# Clean the log directory if necessary
    _clear_log_directory(mltk_model, logger=logger, recursive=clean)
# Create the training logger
logger = mltk_model.create_logger('train', parent=logger)
gpu.initialize(logger=logger)
try:
mltk_model.load_dataset(subset='training', logger=logger, test=mltk_model.test_mode_enabled)
except Exception as e:
prepend_exception_msg(e, 'Failed to load model training dataset')
raise
mltk_model.trigger_event(MltkModelEvent.BEFORE_BUILD_TRAIN_MODEL, logger=logger)
# Build the MLTK model's corresponding Keras model
try:
keras_model = mltk_model.build_model_function(mltk_model)
except Exception as e:
prepend_exception_msg(e, 'Failed to build Keras model')
raise
mltk_model.trigger_event(
MltkModelEvent.AFTER_BUILD_TRAIN_MODEL,
keras_model=keras_model,
logger=logger
)
# Load the weights into the model if necessary
try:
if weights:
weights_path = mltk_model.get_weights_path(weights)
logger.info(f'Loading weights file: {weights_path}')
keras_model.load_weights(weights_path)
except Exception as e:
prepend_exception_msg(e, 'Failed to load weights into Keras model')
raise
# Generate a summary of the model
try:
summary = summarize_model(
mltk_model,
built_model=keras_model
)
logger.info(summary)
with open(f'{mltk_model.log_dir}/{mltk_model.name}.h5.summary.txt', 'w') as f:
f.write(summary)
except Exception as e:
logger.debug(f'Failed to generate model summary, err: {e}', exc_info=e)
logger.info(mltk_model.summarize_dataset())
epochs = epochs or mltk_model.epochs
callbacks = _get_keras_callbacks(mltk_model, epochs=epochs, logger=logger)
epochs = _get_epochs(mltk_model, epochs=epochs, callbacks=callbacks, logger=logger)
initial_epoch = _try_resume_training(mltk_model, keras_model=keras_model, epochs=epochs, resume_epoch=resume_epoch, logger=logger)
    class_weights = compute_class_weights(mltk_model, logger=logger)
fit_kwargs = dict(
x=mltk_model.x,
y=mltk_model.y,
batch_size=mltk_model.batch_size,
steps_per_epoch=mltk_model.steps_per_epoch,
validation_split=mltk_model.validation_split,
validation_data=mltk_model.validation_data,
validation_steps=mltk_model.validation_steps,
validation_batch_size=mltk_model.validation_batch_size,
validation_freq=mltk_model.validation_freq,
shuffle=mltk_model.shuffle,
class_weight=class_weights,
sample_weight=mltk_model.sample_weight,
epochs=epochs,
initial_epoch=initial_epoch,
callbacks=callbacks,
verbose=0 if verbose is False else 1,
)
fit_kwargs.update(mltk_model.train_kwargs)
mltk_model.trigger_event(MltkModelEvent.BEFORE_TRAIN, fit_kwargs=fit_kwargs, logger=logger)
logger.debug(f'Train kwargs:\n{pprint.pformat(fit_kwargs)}')
logger.info('Starting model training ...')
training_history = keras_model.fit(
**fit_kwargs
)
mltk_model.trigger_event(MltkModelEvent.AFTER_TRAIN, training_history=training_history, logger=logger)
try:
mltk_model.unload_dataset()
except Exception as e:
logger.debug(f'Failed to unload dataset, err: {e}', exc_info=e)
keras_model = _save_keras_model_file(mltk_model, keras_model, logger=logger)
results = _save_training_results(
mltk_model,
keras_model,
training_history,
logger=logger,
show=show,
)
if create_archive:
_create_model_archive(mltk_model, logger)
# Quantize the trained model
if quantize and mltk_model.tflite_converter:
try:
quantize_model(
mltk_model,
keras_model=results.keras_model,
update_archive=create_archive
)
except Exception as e:
prepend_exception_msg(e, 'Failed to quantize model')
raise
if mltk_model.on_training_complete is not None:
try:
mltk_model.on_training_complete(results)
except Exception as e:
logger.warning(f'Exception during on_training_complete() callback, err: {e}', exc_info=e)
logger.info('Training complete')
logger.info(f'Training logs here: {mltk_model.log_dir}')
if create_archive:
logger.info(f'Trained model files here: {mltk_model.archive_path}')
logger.close()
if show:
plt.show(block=True)
mltk_model.trigger_event(MltkModelEvent.TRAIN_SHUTDOWN, results=results, logger=logger)
return results
def _get_epochs(
mltk_model: MltkModel,
epochs:int,
callbacks:list,
logger: logging.Logger
) -> int:
"""Update the training epochs as necessary"""
if mltk_model.test_mode_enabled:
logger.info('Forcing epochs=3 since test=true')
return 3
if epochs == -1:
if not contains_class_type(callbacks, keras.callbacks.EarlyStopping):
raise Exception('If mltk_model.epochs = -1 then mltk_model.early_stopping must be specified')
logger.warning('***')
logger.warning('*** NOTE: Setting training epochs to large value since the EarlyStopping callback is being used')
logger.warning('***')
epochs = 99999
return epochs
def _get_keras_callbacks(
mltk_model: MltkModel,
epochs: int,
logger: logging.Logger
) -> list:
"""Populate the Keras training callbacks"""
keras_callbacks = []
keras_callbacks.extend(mltk_model.train_callbacks)
if mltk_model.tensorboard and not contains_class_type(keras_callbacks, keras.callbacks.TensorBoard):
tb_log_dir = mltk_model.create_log_dir('train/tensorboard')
kwargs = dict(log_dir=tb_log_dir)
kwargs.update(mltk_model.tensorboard)
logger.debug('Using default TensorBoard callback with following parameters:')
logger.debug(f'{pprint.pformat(kwargs)}')
cb = keras.callbacks.TensorBoard(**kwargs)
keras_callbacks.append(cb)
logger.info(f'Tensorboard logdir: {tb_log_dir}')
if mltk_model.checkpoint and not contains_class_type(keras_callbacks, keras.callbacks.ModelCheckpoint):
weights_dir = mltk_model.weights_dir
weights_file_format = mltk_model.weights_file_format
kwargs = dict(
filepath=f'{weights_dir}/{weights_file_format}',
)
kwargs.update(mltk_model.checkpoint)
logger.debug('Using default ModelCheckpoint callback with following parameters:')
logger.debug(f'{pprint.pformat(kwargs)}')
cb = keras.callbacks.ModelCheckpoint(**kwargs)
keras_callbacks.append(cb)
if mltk_model.lr_schedule and not contains_class_type(keras_callbacks, keras.callbacks.LearningRateScheduler):
kwargs = dict()
kwargs.update(mltk_model.lr_schedule)
logger.debug('Using default LearningRateScheduler callback with following parameters:')
logger.debug(f'{pprint.pformat(kwargs)}')
cb = keras.callbacks.LearningRateScheduler(**kwargs)
keras_callbacks.append(cb)
if mltk_model.early_stopping and not contains_class_type(keras_callbacks, keras.callbacks.EarlyStopping):
kwargs = dict()
kwargs.update(mltk_model.early_stopping)
logger.debug('Using default EarlyStopping callback with following parameters:')
logger.debug(f'{pprint.pformat(kwargs)}')
cb = keras.callbacks.EarlyStopping(**kwargs)
keras_callbacks.append(cb)
if mltk_model.reduce_lr_on_plateau and not contains_class_type(keras_callbacks, keras.callbacks.ReduceLROnPlateau):
kwargs = dict()
kwargs.update(mltk_model.reduce_lr_on_plateau)
logger.debug('Using default ReduceLROnPlateau callback with following parameters:')
logger.debug(f'{pprint.pformat(kwargs)}')
cb = keras.callbacks.ReduceLROnPlateau(**kwargs)
keras_callbacks.append(cb)
if mltk_model.checkpoints_enabled:
logger.debug('Enabling model checkpoints')
keras_callbacks.append(keras.callbacks.ModelCheckpoint(
filepath=mltk_model.checkpoints_dir + '/weights-{epoch:03d}.h5',
save_weights_only=True,
save_best_only=False,
save_freq='epoch',
))
mltk_model.trigger_event(
MltkModelEvent.POPULATE_TRAIN_CALLBACKS,
keras_callbacks=keras_callbacks,
logger=logger
)
callback_str = ', '.join([str(x.__class__.__name__) for x in keras_callbacks])
logger.debug(f'Using Keras callbacks: {callback_str}')
return keras_callbacks
def _try_resume_training(
mltk_model: MltkModel,
keras_model: KerasModel,
epochs:int,
resume_epoch: int,
logger: logging.Logger
) -> int:
"""Attempt to resume training at either the last available epoch or at the specified epoch
Return initial_epoch
"""
if resume_epoch == 0:
return 0
    # If the --resume <epoch> option was supplied
# then resume at the given checkpoint
if resume_epoch+1 >= epochs:
        raise Exception(f'The resume epoch ({resume_epoch} + 1) must be less than the max training epochs ({epochs})')
checkpoint_path = mltk_model.get_checkpoint_path(resume_epoch)
if checkpoint_path is None:
if resume_epoch == -1:
logger.warning('No training checkpoints found, cannot --resume. Starting from beginning')
return 0
raise Exception(f'Checkpoint not found, cannot resume training at epoch {resume_epoch}')
fn = os.path.basename(checkpoint_path[:-len('.h5')])
checkpoint_epoch = int(fn.split('-')[1])
try:
logger.info(f'Loading checkpoint weights: {checkpoint_path}')
keras_model.load_weights(checkpoint_path)
except Exception as e:
prepend_exception_msg(e, f'Failed to load checkpoint weights: {checkpoint_path}')
raise
logger.warning(f'Resuming training at epoch {checkpoint_epoch+1} of {epochs}')
return checkpoint_epoch
def compute_class_weights(
mltk_model: MltkModel,
logger: logging.Logger
) -> dict:
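    """Safely compute the training data's class weights

    Returns a dict mapping each class id to its weight,
    or None if the weights could not be computed
    """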
try:
class_weights = _compute_class_weights_unsafe(mltk_model, logger=logger)
except Exception as e:
class_weights = None
logger.warning("Failed to compute class weights\nSet my_model.class_weights = 'none' to disable", exc_info=e)
return class_weights
def _compute_class_weights_unsafe(
mltk_model: MltkModel,
logger: logging.Logger
) -> dict:
"""Compute the given data's class weights"""
def _create_weights_dict():
class_weights = mltk_model.class_weights
if not class_weights:
return None
# If a dictionary where the keys directly map to the class ids was given
# then just return the class_weights as-is
if isinstance(class_weights, dict):
if isinstance(list(class_weights.keys())[0], int):
return class_weights
# Otherwise, we need to convert the class weights from:
# {"label1": 1.0, "label2": .5, "lable3": .4}
# to
# {0: 1.0, 1: .5, 2: .4}
try:
            class_ids = list(range(len(mltk_model.classes)))
except Exception as e:
prepend_exception_msg(e,
                'Class weights should be a dict with each key being an integer corresponding to a class'
)
raise
if isinstance(class_weights, list):
return dict(zip(class_ids, class_weights))
if isinstance(class_weights, str):
class_weights = class_weights.lower()
if class_weights == 'none':
return None
if class_weights not in ('balance', 'balanced'):
raise RuntimeError(f'Invalid my_model.class_weights argument given: {class_weights}')
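            # 'balanced' mode uses the sklearn heuristic:
            #   weight_i = n_samples / (n_classes * count_i)
            # computed from my_model.class_counts when available,
            # otherwise via sklearn's compute_class_weight() on my_model.y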
if hasattr(mltk_model, 'class_counts'):
class_counts = mltk_model.class_counts
if 'training' in class_counts:
class_counts = class_counts['training']
                n_samples = sum(class_counts.values())
if n_samples > 0:
n_classes = mltk_model.n_classes
weights = []
for class_name in mltk_model.classes:
weights.append(n_samples / (n_classes * class_counts[class_name]))
return dict(zip(class_ids, weights))
y = mltk_model.y
if y is not None:
weights = compute_class_weight(class_weights, classes=class_ids, y=y)
return dict(zip(class_ids, weights))
            raise RuntimeError(
                'my_model.class_weights=balanced is not supported unless my_model.y or my_model.class_counts is provided.\n'
                'Must manually set the class weights in my_model.class_weights'
            )
if isinstance(class_weights, dict):
weights = {}
for class_id, class_name in enumerate(mltk_model.classes):
if class_name not in class_weights:
raise Exception(f'Class {class_name} not found in class weights')
weights[class_id] = class_weights[class_name]
return weights
raise RuntimeError('Unsupported my_model.class_weight format')
class_weights = _create_weights_dict()
if class_weights:
try:
s = 'Class weights:\n'
            max_len = max(len(x) for x in mltk_model.classes)
for class_id, class_name in enumerate(mltk_model.classes):
s += f'{class_name.rjust(max_len)} = {class_weights[class_id]:.2f}\n'
logger.info(s[:-1])
        except Exception:
logger.info(f'Class weights: {pprint.pformat(class_weights)}')
return class_weights
def _save_keras_model_file(
mltk_model:MltkModel,
keras_model:KerasModel,
logger: logging.Logger
) -> KerasModel:
"""Save the Keras .h5 model file"""
keras_model_dict = dict(value=keras_model)
mltk_model.trigger_event(
MltkModelEvent.BEFORE_SAVE_TRAIN_MODEL,
keras_model=keras_model,
keras_model_dict=keras_model_dict,
logger=logger
)
keras_model = keras_model_dict['value']
# If a custom model saving callback was given then invoke that now
# So that we obtain the correct keras model
if mltk_model.on_save_keras_model is not None:
try:
keras_model = mltk_model.on_save_keras_model(
mltk_model=mltk_model,
keras_model=keras_model,
logger=logger
)
if keras_model is None:
raise RuntimeError('my_model.on_save_keras_model must return a keras model instance')
except Exception as e:
prepend_exception_msg(e, 'Error while saving model using my_model.on_save_keras_model')
raise
# Save the keras model as a .h5 file
try:
h5_path = mltk_model.h5_log_dir_path
logger.info(f'Generating {h5_path}')
        keras_model.save(h5_path, save_format='h5')
except Exception as e:
prepend_exception_msg(e, f'Error while saving model to {h5_path}')
raise
keras_model_dict = dict(value=keras_model)
mltk_model.trigger_event(
MltkModelEvent.AFTER_SAVE_TRAIN_MODEL,
keras_model=keras_model,
keras_model_dict=keras_model_dict,
logger=logger
)
keras_model = keras_model_dict['value']
return keras_model
def _save_training_results(
mltk_model:MltkModel,
keras_model:KerasModel,
training_history,
logger: logging.Logger,
show:bool = False
) -> TrainingResults:
"""Save the training history as .json and .png"""
output_dir = f'{mltk_model.log_dir}/train'
results = TrainingResults(mltk_model, keras_model, training_history)
mltk_model.trigger_event(
MltkModelEvent.BEFORE_SAVE_TRAIN_RESULTS,
keras_model=keras_model,
results=results,
output_dir=output_dir,
logger=logger
)
metric, best_val = results.get_best_metric()
logger.info(f'\n\n*** Best training {metric} = {best_val:.3f}\n\n')
try:
history_json_path = f'{mltk_model.log_dir}/train/training-history.json'
logger.debug(f'Generating {history_json_path}')
with open(history_json_path, 'w') as f:
json.dump(results.asdict(), f, indent=2)
except Exception as e:
logger.warning(f'Error while saving training results to {history_json_path}, err: {e}')
# See https://github.com/keras-team/keras/blob/master/keras/losses.py
    supported_metrics = {
        'accuracy': 'Accuracy',
        'loss': 'Loss',
        'mse': 'Mean Square Error',
        'mae': 'Mean Absolute Error',
        'mape': 'Mean Absolute Percentage Error',
        'msle': 'Mean Square Logarithmic Error',
        'bce': 'Binary Cross-entropy',
        'cce': 'Categorical Cross-entropy',
    }
found_metrics = []
history = results.history
for metric, value in history.items():
        if metric not in supported_metrics:
            continue
        if f'val_{metric}' not in history:
            continue
found_metrics.append(dict(
name=metric,
train=value,
validation=history[f'val_{metric}'],
))
fig, _ = plt.subplots(figsize=(6, 6), clear=True)
fig.suptitle(f'{mltk_model.name} Training History')
# Plot training and validation metrics
for i, metric in enumerate(found_metrics):
plt.subplot(len(found_metrics), 1, i + 1)
plt.plot(metric['train'])
plt.plot(metric['validation'])
plt.title(f'{supported_metrics[metric["name"]]}')
plt.ylabel(supported_metrics[metric['name']])
plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper left')
plt.subplots_adjust(hspace=.5)
training_results_path = f'{mltk_model.log_dir}/train/training-history.png'
logger.debug(f'Generating {training_results_path}')
plt.savefig(training_results_path)
if show:
plt.show(block=False)
else:
fig.clear()
plt.close(fig)
mltk_model.trigger_event(
MltkModelEvent.AFTER_SAVE_TRAIN_RESULTS,
keras_model=keras_model,
results=results,
output_dir=output_dir,
logger=logger
)
return results
def _create_model_archive(
mltk_model: MltkModel,
logger: logging.Logger
):
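    """Create the model archive (.mltk.zip) from the contents of the training log directory"""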
logger.info(f'Creating {mltk_model.archive_path}')
mltk_model.trigger_event(
MltkModelEvent.BEFORE_SAVE_TRAIN_ARCHIVE,
archive_path=mltk_model.archive_path,
logger=logger
)
try:
mltk_model.add_archive_dir('.', create_new=True)
mltk_model.add_archive_file('__mltk_model_spec__')
mltk_model.add_archive_dir('train')
mltk_model.add_archive_dir('dataset', recursive=True)
except Exception as e:
logger.warning(f'Failed to generate model archive, err: {e}', exc_info=e)
mltk_model.trigger_event(
MltkModelEvent.AFTER_SAVE_TRAIN_ARCHIVE,
archive_path=mltk_model.archive_path,
logger=logger
)
def _clear_log_directory(
mltk_model: MltkModel,
logger: logging.Logger,
recursive=False
):
"""Clear any previous training logs"""
training_log_dir = f'{mltk_model.log_dir}/train'
if recursive:
logger.info(f'Cleaning all files in {mltk_model.log_dir}')
clean_directory(mltk_model.log_dir)
elif os.path.exists(training_log_dir):
for fn in os.listdir(training_log_dir):
path = f'{training_log_dir}/{fn}'
if os.path.isfile(path):
logger.debug(f'Removing {path}')
try:
os.remove(path)
except Exception as e:
logger.debug(f'Failed to remove {path}, err: {e}')
if os.path.exists(mltk_model.log_dir):
for fn in os.listdir(mltk_model.log_dir):
path = f'{mltk_model.log_dir}/{fn}'
if os.path.isfile(path):
logger.debug(f'Removing {path}')
try:
os.remove(path)
except Exception as e:
logger.debug(f'Failed to remove {path}, err: {e}')
if os.path.exists(mltk_model.archive_path):
logger.debug(f'Removing {mltk_model.archive_path}')
try:
os.remove(mltk_model.archive_path)
except Exception as e:
            logger.debug(f'Failed to remove {mltk_model.archive_path}, err: {e}')