# Reconstructed source of mltk.core.train_model.
# Standard-library and third-party imports used below. Package-internal names
# (MltkModel, MltkModelEvent, TrainMixin, DatasetMixin, KerasModel,
# TrainingResults, load_mltk_model, summarize_model, quantize_model, gpu,
# get_mltk_logger, prepend_exception_msg, contains_class_type, clean_directory)
# are assumed to be imported from the surrounding mltk package.
import os
import json
import pprint
import logging
from typing import Union

import matplotlib.pyplot as plt
from tensorflow import keras
from sklearn.utils.class_weight import compute_class_weight


def train_model(
    model: Union[MltkModel, str],
    weights: str = None,
    epochs: int = None,
    resume_epoch: int = 0,
    verbose: bool = None,
    clean: bool = False,
    quantize: bool = True,
    create_archive: bool = True,
    show: bool = False,
    test: bool = False,
    post_process: bool = False
) -> TrainingResults:
    """Train a model using Keras and Tensorflow

    .. seealso::
       * `Model Training Guide <https://siliconlabs.github.io/mltk/docs/guides/model_training.html>`_
       * `Model Training API Examples <https://siliconlabs.github.io/mltk/mltk/examples/train_model.html>`_
       * `KerasModel.fit() <https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit>`_

    Args:
        model: :py:class:`mltk.core.MltkModel` instance, name of MLTK model, or path to
            model specification script (.py)

            __Note:__ If the model is in "test mode" then it will train for only 3 epochs
        weights: Optional file path of model weights to load before training
        epochs: Optional number of epochs to train the model. This overrides the
            mltk_model.epochs attribute
        resume_epoch: Optional, resume training at the given epoch
        verbose: Optional, verbosely print to the logger while training
        clean: Optional, clean the log directory before training
        quantize: Optional, quantize the model after training successfully completes
        create_archive: Optional, create an archive (.mltk.zip) of the training results
            and generated model files
        show: Optional, show the training results diagram
        test: Optional, load the model in "test mode" if true
        post_process: Allows for post-processing the training results
            (e.g. uploading to a cloud) if supported by the given MltkModel

    Returns:
        The model TrainingResults
    """
    if isinstance(model, MltkModel):
        mltk_model = model
        if test:
            mltk_model.enable_test_mode()
    elif isinstance(model, str):
        if model.endswith(('.tflite', '.h5', '.zip')):
            raise ValueError(
                'Must provide name of MLTK model '
                'or path to model specification script (.py)'
            )
        mltk_model = load_mltk_model(model, test=test)
    else:
        raise ValueError(
            'Must provide MltkModel instance, name of MLTK model, or path to '
            'model specification script (.py)'
        )

    if not isinstance(mltk_model, TrainMixin):
        raise ValueError('model argument must be an MltkModel instance that inherits TrainMixin')
    if not isinstance(mltk_model, DatasetMixin):
        raise ValueError('model argument must be an MltkModel instance that inherits DatasetMixin')

    logger = get_mltk_logger()
    mltk_model.trigger_event(
        MltkModelEvent.TRAIN_STARTUP,
        post_process=post_process,
        logger=logger
    )

    # Ensure the MltkModel archive is writable before we start training
    if create_archive:
        mltk_model.check_archive_file_is_writable(throw_exception=True)

    # Clean the log directory if necessary
    if clean is not None:
        _clear_log_directory(mltk_model, logger=logger, recursive=clean)

    # Create the training logger
    logger = mltk_model.create_logger('train', parent=logger)
    gpu.initialize(logger=logger)

    try:
        mltk_model.load_dataset(
            subset='training',
            logger=logger,
            test=mltk_model.test_mode_enabled
        )
    except Exception as e:
        prepend_exception_msg(e, 'Failed to load model training dataset')
        raise

    mltk_model.trigger_event(MltkModelEvent.BEFORE_BUILD_TRAIN_MODEL, logger=logger)

    # Build the MLTK model's corresponding Keras model
    try:
        keras_model = mltk_model.build_model_function(mltk_model)
    except Exception as e:
        prepend_exception_msg(e, 'Failed to build Keras model')
        raise

    mltk_model.trigger_event(
        MltkModelEvent.AFTER_BUILD_TRAIN_MODEL,
        keras_model=keras_model,
        logger=logger
    )

    # Load the weights into the model if necessary
    try:
        if weights:
            weights_path = mltk_model.get_weights_path(weights)
            logger.info(f'Loading weights file: {weights_path}')
            keras_model.load_weights(weights_path)
    except Exception as e:
        prepend_exception_msg(e, 'Failed to load weights into Keras model')
        raise

    # Generate a summary of the model
    try:
        summary = summarize_model(mltk_model, built_model=keras_model)
        logger.info(summary)
        with open(f'{mltk_model.log_dir}/{mltk_model.name}.h5.summary.txt', 'w') as f:
            f.write(summary)
    except Exception as e:
        logger.debug(f'Failed to generate model summary, err: {e}', exc_info=e)

    logger.info(mltk_model.summarize_dataset())

    epochs = epochs or mltk_model.epochs
    callbacks = _get_keras_callbacks(mltk_model, epochs=epochs, logger=logger)
    epochs = _get_epochs(mltk_model, epochs=epochs, callbacks=callbacks, logger=logger)
    initial_epoch = _try_resume_training(
        mltk_model,
        keras_model=keras_model,
        epochs=epochs,
        resume_epoch=resume_epoch,
        logger=logger
    )

    # compute_class_weights() handles its own exceptions and returns None on failure
    class_weights = compute_class_weights(mltk_model, logger=logger)

    fit_kwargs = dict(
        x=mltk_model.x,
        y=mltk_model.y,
        batch_size=mltk_model.batch_size,
        steps_per_epoch=mltk_model.steps_per_epoch,
        validation_split=mltk_model.validation_split,
        validation_data=mltk_model.validation_data,
        validation_steps=mltk_model.validation_steps,
        validation_batch_size=mltk_model.validation_batch_size,
        validation_freq=mltk_model.validation_freq,
        shuffle=mltk_model.shuffle,
        class_weight=class_weights,
        sample_weight=mltk_model.sample_weight,
        epochs=epochs,
        initial_epoch=initial_epoch,
        callbacks=callbacks,
        verbose=0 if verbose is False else 1,
    )
    fit_kwargs.update(mltk_model.train_kwargs)

    mltk_model.trigger_event(MltkModelEvent.BEFORE_TRAIN, fit_kwargs=fit_kwargs, logger=logger)

    logger.debug(f'Train kwargs:\n{pprint.pformat(fit_kwargs)}')
    logger.info('Starting model training ...')
    training_history = keras_model.fit(**fit_kwargs)

    mltk_model.trigger_event(MltkModelEvent.AFTER_TRAIN, training_history=training_history, logger=logger)

    try:
        mltk_model.unload_dataset()
    except Exception as e:
        logger.debug(f'Failed to unload dataset, err: {e}', exc_info=e)

    keras_model = _save_keras_model_file(mltk_model, keras_model, logger=logger)
    results = _save_training_results(
        mltk_model,
        keras_model,
        training_history,
        logger=logger,
        show=show,
    )

    if create_archive:
        _create_model_archive(mltk_model, logger)

    # Quantize the trained model
    if quantize and mltk_model.tflite_converter:
        try:
            quantize_model(
                mltk_model,
                keras_model=results.keras_model,
                update_archive=create_archive
            )
        except Exception as e:
            prepend_exception_msg(e, 'Failed to quantize model')
            raise

    if mltk_model.on_training_complete is not None:
        try:
            mltk_model.on_training_complete(results)
        except Exception as e:
            logger.warning(f'Exception during on_training_complete() callback, err: {e}', exc_info=e)

    logger.info('Training complete')
    logger.info(f'Training logs here: {mltk_model.log_dir}')
    if create_archive:
        logger.info(f'Trained model files here: {mltk_model.archive_path}')
    logger.close()

    if show:
        plt.show(block=True)

    mltk_model.trigger_event(MltkModelEvent.TRAIN_SHUTDOWN, results=results, logger=logger)

    return results
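
# -----------------------------------------------------------------------------
# Example usage -- a minimal sketch, not part of the MLTK API. It assumes the
# bundled 'image_example1' reference model is available; substitute your own
# MltkModel instance or model specification script as needed.
def _example_train_model_usage():
    # Train in "test mode" (3 epochs) with a cleaned log directory.
    # Quantization and archive creation run by default after training.
    results = train_model('image_example1', clean=True, test=True)
    print(results)

    # Resume an interrupted training run at the last saved checkpoint
    # (requires my_model.checkpoints_enabled so per-epoch weights were saved)
    train_model('image_example1', resume_epoch=-1)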

def _get_epochs(
    mltk_model: MltkModel,
    epochs: int,
    callbacks: list,
    logger: logging.Logger
) -> int:
    """Update the training epochs as necessary"""
    if mltk_model.test_mode_enabled:
        logger.info('Forcing epochs=3 since test=true')
        return 3

    if epochs == -1:
        if not contains_class_type(callbacks, keras.callbacks.EarlyStopping):
            raise Exception('If mltk_model.epochs = -1 then mltk_model.early_stopping must be specified')
        logger.warning('***')
        logger.warning('*** NOTE: Setting training epochs to large value since the EarlyStopping callback is being used')
        logger.warning('***')
        epochs = 99999

    return epochs


def _get_keras_callbacks(
    mltk_model: MltkModel,
    epochs: int,
    logger: logging.Logger
) -> list:
    """Populate the Keras training callbacks"""
    keras_callbacks = []
    keras_callbacks.extend(mltk_model.train_callbacks)

    if mltk_model.tensorboard and not contains_class_type(keras_callbacks, keras.callbacks.TensorBoard):
        tb_log_dir = mltk_model.create_log_dir('train/tensorboard')
        kwargs = dict(log_dir=tb_log_dir)
        kwargs.update(mltk_model.tensorboard)
        logger.debug('Using default TensorBoard callback with following parameters:')
        logger.debug(f'{pprint.pformat(kwargs)}')
        cb = keras.callbacks.TensorBoard(**kwargs)
        keras_callbacks.append(cb)
        logger.info(f'Tensorboard logdir: {tb_log_dir}')

    if mltk_model.checkpoint and not contains_class_type(keras_callbacks, keras.callbacks.ModelCheckpoint):
        weights_dir = mltk_model.weights_dir
        weights_file_format = mltk_model.weights_file_format
        kwargs = dict(filepath=f'{weights_dir}/{weights_file_format}')
        kwargs.update(mltk_model.checkpoint)
        logger.debug('Using default ModelCheckpoint callback with following parameters:')
        logger.debug(f'{pprint.pformat(kwargs)}')
        cb = keras.callbacks.ModelCheckpoint(**kwargs)
        keras_callbacks.append(cb)

    if mltk_model.lr_schedule and not contains_class_type(keras_callbacks, keras.callbacks.LearningRateScheduler):
        kwargs = dict()
        kwargs.update(mltk_model.lr_schedule)
        logger.debug('Using default LearningRateScheduler callback with following parameters:')
        logger.debug(f'{pprint.pformat(kwargs)}')
        cb = keras.callbacks.LearningRateScheduler(**kwargs)
        keras_callbacks.append(cb)

    if mltk_model.early_stopping and not contains_class_type(keras_callbacks, keras.callbacks.EarlyStopping):
        kwargs = dict()
        kwargs.update(mltk_model.early_stopping)
        logger.debug('Using default EarlyStopping callback with following parameters:')
        logger.debug(f'{pprint.pformat(kwargs)}')
        cb = keras.callbacks.EarlyStopping(**kwargs)
        keras_callbacks.append(cb)

    if mltk_model.reduce_lr_on_plateau and not contains_class_type(keras_callbacks, keras.callbacks.ReduceLROnPlateau):
        kwargs = dict()
        kwargs.update(mltk_model.reduce_lr_on_plateau)
        logger.debug('Using default ReduceLROnPlateau callback with following parameters:')
        logger.debug(f'{pprint.pformat(kwargs)}')
        cb = keras.callbacks.ReduceLROnPlateau(**kwargs)
        keras_callbacks.append(cb)

    if mltk_model.checkpoints_enabled:
        logger.debug('Enabling model checkpoints')
        keras_callbacks.append(keras.callbacks.ModelCheckpoint(
            filepath=mltk_model.checkpoints_dir + '/weights-{epoch:03d}.h5',
            save_weights_only=True,
            save_best_only=False,
            save_freq='epoch',
        ))

    mltk_model.trigger_event(
        MltkModelEvent.POPULATE_TRAIN_CALLBACKS,
        keras_callbacks=keras_callbacks,
        logger=logger
    )

    callback_str = ', '.join(x.__class__.__name__ for x in keras_callbacks)
    logger.debug(f'Using Keras callbacks: {callback_str}')

    return keras_callbacks
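
# -----------------------------------------------------------------------------
# Example: configuring the default training callbacks -- an illustrative
# sketch. Each MltkModel attribute holds a dict of keyword arguments that
# _get_keras_callbacks() above forwards to the corresponding keras.callbacks
# class; the argument values below are assumptions, not recommended defaults.
def _example_configure_train_callbacks(my_model: MltkModel):
    my_model.tensorboard = dict(histogram_freq=1, write_graph=False)
    my_model.checkpoint = dict(monitor='val_accuracy', save_best_only=True)
    my_model.lr_schedule = dict(schedule=lambda epoch, lr: lr * 0.95)
    my_model.early_stopping = dict(monitor='val_accuracy', patience=15)
    my_model.reduce_lr_on_plateau = dict(monitor='val_loss', factor=0.5, patience=5)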

def _try_resume_training(
    mltk_model: MltkModel,
    keras_model: KerasModel,
    epochs: int,
    resume_epoch: int,
    logger: logging.Logger
) -> int:
    """Attempt to resume training at either the last available epoch
    or at the specified epoch

    Return initial_epoch
    """
    if resume_epoch == 0:
        return 0

    # If the --resume <epoch> option was supplied
    # then resume at the given checkpoint
    if resume_epoch + 1 >= epochs:
        raise Exception(f'The resume epoch ({resume_epoch}+1) must be less than the max training epochs ({epochs})')

    checkpoint_path = mltk_model.get_checkpoint_path(resume_epoch)
    if checkpoint_path is None:
        if resume_epoch == -1:
            logger.warning('No training checkpoints found, cannot --resume. Starting from beginning')
            return 0
        raise Exception(f'Checkpoint not found, cannot resume training at epoch {resume_epoch}')

    fn = os.path.basename(checkpoint_path[:-len('.h5')])
    checkpoint_epoch = int(fn.split('-')[1])

    try:
        logger.info(f'Loading checkpoint weights: {checkpoint_path}')
        keras_model.load_weights(checkpoint_path)
    except Exception as e:
        prepend_exception_msg(e, f'Failed to load checkpoint weights: {checkpoint_path}')
        raise

    logger.warning(f'Resuming training at epoch {checkpoint_epoch+1} of {epochs}')

    return checkpoint_epoch


def compute_class_weights(mltk_model: MltkModel, logger: logging.Logger) -> dict:
    """Compute the training data's class weights, returning None on failure"""
    try:
        class_weights = _compute_class_weights_unsafe(mltk_model, logger=logger)
    except Exception as e:
        class_weights = None
        logger.warning(
            "Failed to compute class weights\nSet my_model.class_weights = 'none' to disable",
            exc_info=e
        )
    return class_weights
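
# -----------------------------------------------------------------------------
# Example: the my_model.class_weights formats accepted by
# _compute_class_weights_unsafe() below -- an illustrative sketch assuming a
# model with classes = ['cat', 'dog', 'bird'].
def _example_class_weight_formats(my_model: MltkModel):
    # Dict keyed by integer class id -- used as-is
    my_model.class_weights = {0: 1.0, 1: 0.5, 2: 0.4}
    # Dict keyed by class label -- converted to {0: 1.0, 1: 0.5, 2: 0.4}
    my_model.class_weights = {'cat': 1.0, 'dog': 0.5, 'bird': 0.4}
    # List ordered by class id
    my_model.class_weights = [1.0, 0.5, 0.4]
    # 'balanced' -- computed from my_model.class_counts or my_model.y
    my_model.class_weights = 'balanced'
    # 'none' -- disable class weighting
    my_model.class_weights = 'none'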

def _compute_class_weights_unsafe(mltk_model: MltkModel, logger: logging.Logger) -> dict:
    """Compute the given data's class weights"""

    def _create_weights_dict():
        class_weights = mltk_model.class_weights
        if not class_weights:
            return None

        # If a dictionary where the keys directly map to the class ids was given
        # then just return the class_weights as-is
        if isinstance(class_weights, dict):
            if isinstance(list(class_weights.keys())[0], int):
                return class_weights

        # Otherwise, we need to convert the class weights from:
        # {"label1": 1.0, "label2": .5, "label3": .4}
        # to
        # {0: 1.0, 1: .5, 2: .4}
        try:
            class_ids = list(range(len(mltk_model.classes)))
        except Exception as e:
            prepend_exception_msg(e, 'Class weights should be a dict with each key being an integer corresponding to a class')
            raise

        if isinstance(class_weights, list):
            return dict(zip(class_ids, class_weights))

        if isinstance(class_weights, str):
            class_weights = class_weights.lower()
            if class_weights == 'none':
                return None
            if class_weights not in ('balance', 'balanced'):
                raise RuntimeError(f'Invalid my_model.class_weights argument given: {class_weights}')

            if hasattr(mltk_model, 'class_counts'):
                class_counts = mltk_model.class_counts
                if 'training' in class_counts:
                    class_counts = class_counts['training']
                n_samples = sum(class_counts.values())
                if n_samples > 0:
                    n_classes = mltk_model.n_classes
                    weights = []
                    for class_name in mltk_model.classes:
                        weights.append(n_samples / (n_classes * class_counts[class_name]))
                    return dict(zip(class_ids, weights))

            y = mltk_model.y
            if y is not None:
                weights = compute_class_weight('balanced', classes=class_ids, y=y)
                return dict(zip(class_ids, weights))

            raise RuntimeError(
                'my_model.class_weights=balanced not supported if my_model.y or mltk_model.class_counts not provided.\n'
                'Must manually set class weights in my_model.class_weights'
            )

        if isinstance(class_weights, dict):
            weights = {}
            for class_id, class_name in enumerate(mltk_model.classes):
                if class_name not in class_weights:
                    raise Exception(f'Class {class_name} not found in class weights')
                weights[class_id] = class_weights[class_name]
            return weights

        raise RuntimeError('Unsupported my_model.class_weights format')

    class_weights = _create_weights_dict()
    if class_weights:
        try:
            s = 'Class weights:\n'
            max_len = max(len(x) for x in mltk_model.classes)
            for class_id, class_name in enumerate(mltk_model.classes):
                s += f'{class_name.rjust(max_len)} = {class_weights[class_id]:.2f}\n'
            logger.info(s[:-1])
        except Exception:
            logger.info(f'Class weights: {pprint.pformat(class_weights)}')

    return class_weights


def _save_keras_model_file(
    mltk_model: MltkModel,
    keras_model: KerasModel,
    logger: logging.Logger
) -> KerasModel:
    """Save the Keras .h5 model file"""
    keras_model_dict = dict(value=keras_model)
    mltk_model.trigger_event(
        MltkModelEvent.BEFORE_SAVE_TRAIN_MODEL,
        keras_model=keras_model,
        keras_model_dict=keras_model_dict,
        logger=logger
    )
    keras_model = keras_model_dict['value']

    # If a custom model saving callback was given then invoke that now
    # so that we obtain the correct Keras model
    if mltk_model.on_save_keras_model is not None:
        try:
            keras_model = mltk_model.on_save_keras_model(
                mltk_model=mltk_model,
                keras_model=keras_model,
                logger=logger
            )
            if keras_model is None:
                raise RuntimeError('my_model.on_save_keras_model must return a keras model instance')
        except Exception as e:
            prepend_exception_msg(e, 'Error while saving model using my_model.on_save_keras_model')
            raise

    # Save the Keras model as a .h5 file
    h5_path = mltk_model.h5_log_dir_path
    try:
        logger.info(f'Generating {h5_path}')
        keras_model.save(h5_path, save_format='h5')
    except Exception as e:
        prepend_exception_msg(e, f'Error while saving model to {h5_path}')
        raise

    keras_model_dict = dict(value=keras_model)
    mltk_model.trigger_event(
        MltkModelEvent.AFTER_SAVE_TRAIN_MODEL,
        keras_model=keras_model,
        keras_model_dict=keras_model_dict,
        logger=logger
    )
    keras_model = keras_model_dict['value']

    return keras_model
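
# -----------------------------------------------------------------------------
# Example: a my_model.on_save_keras_model hook as invoked by
# _save_keras_model_file() above -- an illustrative sketch. The hook must
# return a Keras model instance; here it just logs and returns the trained
# model unchanged.
def _example_on_save_keras_model(mltk_model: MltkModel, keras_model: KerasModel, logger: logging.Logger) -> KerasModel:
    logger.info(f'Saving {mltk_model.name} without post-processing')
    return keras_model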

def _save_training_results(
    mltk_model: MltkModel,
    keras_model: KerasModel,
    training_history,
    logger: logging.Logger,
    show: bool = False
) -> TrainingResults:
    """Save the training history as .json and .png"""
    output_dir = f'{mltk_model.log_dir}/train'
    results = TrainingResults(mltk_model, keras_model, training_history)
    mltk_model.trigger_event(
        MltkModelEvent.BEFORE_SAVE_TRAIN_RESULTS,
        keras_model=keras_model,
        results=results,
        output_dir=output_dir,
        logger=logger
    )

    metric, best_val = results.get_best_metric()
    logger.info(f'\n\n*** Best training {metric} = {best_val:.3f}\n\n')

    history_json_path = f'{output_dir}/training-history.json'
    try:
        logger.debug(f'Generating {history_json_path}')
        with open(history_json_path, 'w') as f:
            json.dump(results.asdict(), f, indent=2)
    except Exception as e:
        logger.warning(f'Error while saving training results to {history_json_path}, err: {e}')

    # See https://github.com/keras-team/keras/blob/master/keras/losses.py
    supported_metrics = {
        'accuracy': 'Accuracy',
        'loss': 'Loss',
        'mse': 'Mean Square Error',
        'mae': 'Mean Absolute Error',
        'mape': 'Mean Absolute Percentage Error',
        'msle': 'Mean Square Logarithmic Error',
        'bce': 'Binary Cross-entropy',
        'cce': 'Categorical Cross-entropy',
    }

    found_metrics = []
    history = results.history
    for metric, value in history.items():
        if metric not in supported_metrics:
            continue
        if f'val_{metric}' not in history:
            continue
        found_metrics.append(dict(
            name=metric,
            train=value,
            validation=history[f'val_{metric}'],
        ))

    fig, _ = plt.subplots(figsize=(6, 6), clear=True)
    fig.suptitle(f'{mltk_model.name} Training History')

    # Plot training and validation metrics
    for i, metric in enumerate(found_metrics):
        plt.subplot(len(found_metrics), 1, i + 1)
        plt.plot(metric['train'])
        plt.plot(metric['validation'])
        plt.title(supported_metrics[metric['name']])
        plt.ylabel(supported_metrics[metric['name']])
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper left')
        plt.subplots_adjust(hspace=.5)

    training_results_path = f'{output_dir}/training-history.png'
    logger.debug(f'Generating {training_results_path}')
    plt.savefig(training_results_path)
    if show:
        plt.show(block=False)
    else:
        fig.clear()
        plt.close(fig)

    mltk_model.trigger_event(
        MltkModelEvent.AFTER_SAVE_TRAIN_RESULTS,
        keras_model=keras_model,
        results=results,
        output_dir=output_dir,
        logger=logger
    )

    return results


def _create_model_archive(mltk_model: MltkModel, logger: logging.Logger):
    logger.info(f'Creating {mltk_model.archive_path}')
    mltk_model.trigger_event(
        MltkModelEvent.BEFORE_SAVE_TRAIN_ARCHIVE,
        archive_path=mltk_model.archive_path,
        logger=logger
    )
    try:
        mltk_model.add_archive_dir('.', create_new=True)
        mltk_model.add_archive_file('__mltk_model_spec__')
        mltk_model.add_archive_dir('train')
        mltk_model.add_archive_dir('dataset', recursive=True)
    except Exception as e:
        logger.warning(f'Failed to generate model archive, err: {e}', exc_info=e)
    mltk_model.trigger_event(
        MltkModelEvent.AFTER_SAVE_TRAIN_ARCHIVE,
        archive_path=mltk_model.archive_path,
        logger=logger
    )


def _clear_log_directory(mltk_model: MltkModel, logger: logging.Logger, recursive=False):
    """Clear any previous training logs"""
    training_log_dir = f'{mltk_model.log_dir}/train'
    if recursive:
        logger.info(f'Cleaning all files in {mltk_model.log_dir}')
        clean_directory(mltk_model.log_dir)
    elif os.path.exists(training_log_dir):
        for fn in os.listdir(training_log_dir):
            path = f'{training_log_dir}/{fn}'
            if os.path.isfile(path):
                logger.debug(f'Removing {path}')
                try:
                    os.remove(path)
                except Exception as e:
                    logger.debug(f'Failed to remove {path}, err: {e}')

    if os.path.exists(mltk_model.log_dir):
        for fn in os.listdir(mltk_model.log_dir):
            path = f'{mltk_model.log_dir}/{fn}'
            if os.path.isfile(path):
                logger.debug(f'Removing {path}')
                try:
                    os.remove(path)
                except Exception as e:
                    logger.debug(f'Failed to remove {path}, err: {e}')

    if os.path.exists(mltk_model.archive_path):
        logger.debug(f'Removing {mltk_model.archive_path}')
        try:
            os.remove(mltk_model.archive_path)
        except Exception as e:
            logger.debug(f'Failed to remove {mltk_model.archive_path}, err: {e}')