Source code for mltk.core.update_model_parameters

import os
import copy
from typing import Union

from mltk.utils.path import fullpath
from mltk.utils.hasher import generate_hash
from mltk.utils.string_formatting import iso_time_str
from mltk.core import (TfliteModel, TfliteModelParameters, TFLITE_METADATA_TAG)
from .model import (
    MltkModel,
    load_mltk_model,
    load_mltk_model_with_path
)
from .utils import get_mltk_logger

def update_model_parameters(
    model:Union[MltkModel,TfliteModel,str],
    params:dict=None,
    description:str=None,
    output:str=None,
    accelerator:str=None,
)-> Union[str,TfliteModel]:
    """Update the parameters of a previously trained model

    This updates the metadata of a previously trained `.tflite` model.
    The parameters are taken from either the given :py:class:`mltk.core.MltkModel`'s
    python script or the given "params" dictionary and added to the `.tflite` model file.

    .. note:: The `.tflite` metadata is only modified.
       The weights and model structure of the `.tflite` file are NOT modified.

    .. seealso::
       The "Model Parameters" guide and the "Update model parameters" API
       examples in the MLTK documentation.

    Args:
        model: One of:

            * the name of a model,
            * a :py:class:`mltk.core.MltkModel` instance,
            * a :py:class:`mltk.core.TfliteModel` instance,
            * the path to a `.tflite` model file,
            * the path to a `.mltk.zip` model archive.
        params: Optional dictionary of parameters to add to the `.tflite`.
            If omitted then the model argument must be a
            :py:class:`mltk.core.MltkModel` instance or a model name.
        description: Optional description to add to the `.tflite`.
            When the model is resolved to an MltkModel and no description is
            given, the MltkModel's own description is used.
        output: Optional, directory path or file path to the generated `.tflite` file.
            If none then generate in the model log directory (or overwrite the
            input `.tflite` when a `.tflite`/TfliteModel was given).
            If output='tflite_model', then return the
            :py:class:`mltk.core.TfliteModel` object instead of a `.tflite` file path.
        accelerator: Optional hardware accelerator to use when determining the
            ``runtime_memory_size`` parameter. If None then default to the CMSIS
            kernels for calculating the required tensor arena size.

    Returns:
        The file path to the generated `.tflite` OR the TfliteModel object
        if output=`tflite_model`.

    Raises:
        Exception: If ``params`` is given but is not a dictionary, or if
            ``params`` is omitted while ``model`` is a `.tflite` path,
            a TfliteModel instance or a `.mltk.zip` path.
        RuntimeError: If ``model`` has an unsupported type or the output
            path cannot be determined.
    """
    update_archive = False
    mltk_model = None
    tflite_model = None
    tflite_path = None
    logger = get_mltk_logger()

    if params and not isinstance(params, dict):
        raise Exception('params argument must be a dictionary')

    if isinstance(model, TfliteModel) or (isinstance(model, str) and model.endswith('.tflite')):
        # A bare .tflite carries no python model script,
        # so the parameters can only come from the caller
        if params is None:
            raise Exception('Must provide "param" argument if "model" is a .tflite path or TfliteModel instance')

        if isinstance(model, str):
            tflite_path = fullpath(model)
            tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)
        else:
            tflite_model = model
            tflite_path = tflite_model.path

        # See if the .tflite already has parameters
        params_fb = tflite_model.get_metadata(TFLITE_METADATA_TAG)
        if params_fb is not None:
            # Merge the given params on top of the existing ones
            model_parameters = TfliteModelParameters.deserialize(params_fb)
            if params:
                model_parameters.update(params)
        else:
            model_parameters = TfliteModelParameters(params)

    elif isinstance(model, (MltkModel, str)):
        if isinstance(model, MltkModel):
            mltk_model = model
            # Only touch the model archive when the result is
            # written to its default location
            if not output:
                update_archive = mltk_model.check_archive_file_is_writable()

        elif model.endswith('.mltk.zip'):
            if params is None:
                raise Exception('Must provide "param" argument if "model" is a path')

            archive_path = model
            mltk_model = load_mltk_model_with_path(model, logger=logger)
            # Typically, the mltk_model.tflite_archive_path is determined from the path of the mltk_model.model_specification_path
            # However, in this special case, the model archive path is given
            # So we need to add a bit of hackery to ensure the mltk_model.tflite_archive_path points to the correct location
            # by overriding the mltk_model.model_specification_path
            mltk_model._attributes._entries['model_specification_path'].value = archive_path.replace('.mltk.zip', '.py') # pylint: disable=protected-access
            mltk_model.check_archive_file_is_writable(throw_exception=True)
            update_archive = True

        else:
            # Otherwise the string is treated as a model name
            mltk_model = load_mltk_model(model)
            if not output:
                update_archive = mltk_model.check_archive_file_is_writable()

        tflite_path = mltk_model.tflite_archive_path
        tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)
        # An explicitly given description takes precedence
        # over the one defined by the model
        description = description or mltk_model.description

        mltk_model.populate_model_parameters()
        model_parameters = mltk_model.model_parameters

        # If additional parameters were given to this API
        # then add them to the model now
        if params:
            model_parameters.update(params)

    else:
        raise RuntimeError('model argument must be model name, MltkModel or TfliteModel instance, or path to .tflite or .mltk.zip')

    # Determine what the return value should be
    retval = _resolve_output_path(output, mltk_model, tflite_path)

    # Add the model description if necessary
    if description:
        tflite_model.description = description

    # Add the default parameters to the model's metadata
    add_default_parameters(
        tflite_model,
        model_parameters,
        forced_params=params,
        accelerator=accelerator
    )

    if retval == 'tflite_model':
        retval = tflite_model
    else:
        # Write the updated flatbuffer to the resolved file path
        tflite_model.save(retval)
        if update_archive:
            logger.info(f'Updating {mltk_model.archive_path}')
            mltk_model.add_archive_file('__mltk_model_spec__')
            mltk_model.add_archive_file(retval)

    return retval


def _resolve_output_path(output, mltk_model, tflite_path):
    """Return where the updated `.tflite` should go

    Args:
        output: The ``output`` argument given to :py:func:`update_model_parameters`
        mltk_model: The resolved MltkModel instance, or None if the input
            was a `.tflite` path or TfliteModel instance
        tflite_path: File path of the input `.tflite`, or None if unknown

    Returns:
        The literal string ``'tflite_model'`` if the caller asked for the
        TfliteModel object, otherwise the `.tflite` file path to write.

    Raises:
        RuntimeError: If no output location can be derived from the arguments.
    """
    # If an output argument was supplied
    if output:
        # If we should just return the TfliteModel instance
        if output == 'tflite_model':
            return 'tflite_model'

        # If the path to a .tflite was given then just use that
        if output.endswith('.tflite'):
            return output

        # Else we were given a directory
        # If the input was an MltkModel instance or name
        if mltk_model is not None:
            # The path is the given output directory
            # plus the model name and .tflite extension
            if mltk_model.test_mode_enabled:
                return f'{output}/{mltk_model.name}.test.tflite'
            return f'{output}/{mltk_model.name}.tflite'

        # Otherwise the path is the given output directory
        # plus the input .tflite filename
        retval_name = 'my_model.tflite' if tflite_path is None else os.path.basename(tflite_path)
        return f'{output}/{retval_name}'

    # Otherwise no output was given
    # If an MltkModel instance or name was given
    if mltk_model is not None:
        # Then update the model file in its log directory
        if mltk_model.test_mode_enabled:
            return f'{mltk_model.create_log_dir()}/{mltk_model.name}.test.tflite'
        return f'{mltk_model.create_log_dir()}/{mltk_model.name}.tflite'

    # Else we just update the input .tflite file
    if tflite_path is not None:
        return tflite_path

    raise RuntimeError('Failed to determine output path')
def add_default_parameters(
    tflite_model: TfliteModel,
    model_parameters: TfliteModelParameters,
    forced_params:dict=None,
    accelerator:str=None,
    add_runtime_memory_size=True
):
    """Add the default parameters to the model's metadata

    The default parameters are ``runtime_memory_size``, ``hash`` and ``date``.
    They are added to a copy of ``model_parameters``, which is then serialized
    and stored in ``tflite_model`` under ``TFLITE_METADATA_TAG``.

    Args:
        tflite_model: The model whose metadata is updated in-place
        model_parameters: The parameters to serialize into the model.
            This object is deep-copied; the caller's instance is not modified.
        forced_params: Optional dictionary. If it contains a (truthy for
            ``hash``/``date``, non-None for ``runtime_memory_size``) value for
            ``runtime_memory_size``, ``hash`` or ``date`` then that value is
            used instead of the calculated one.
        accelerator: Optional hardware accelerator passed to
            ``TfliteMicro.load_tflite_model()`` when determining the
            ``runtime_memory_size``
        add_runtime_memory_size: If true, add the ``runtime_memory_size``
            parameter. If false, any existing ``runtime_memory_size``
            parameter is removed.
    """
    model_parameters = copy.deepcopy(model_parameters)
    forced_params = forced_params or {}
    forced_runtime_memory_size = forced_params.get('runtime_memory_size', None)
    forced_date = forced_params.get('date', None)
    forced_hash = forced_params.get('hash', None)

    # Try to determine the RAM required by TFLM
    # and add it to the parameters
    if add_runtime_memory_size:
        if forced_runtime_memory_size is None:
            try:
                # Imported lazily, within the best-effort block,
                # so a failure to import is handled like a failure to profile
                from mltk.core.tflite_micro import TfliteMicro
                tflm_model = TfliteMicro.load_tflite_model(
                    tflite_model,
                    accelerator=accelerator,
                    runtime_buffer_size=-1 # Set the runtime buffer size to -1 so the optimal size is automatically found
                )
                try:
                    model_parameters['runtime_memory_size'] = tflm_model.details.runtime_memory_size
                finally:
                    TfliteMicro.unload_model(tflm_model)
            except Exception: # pylint: disable=broad-except
                # Best-effort: if the size cannot be determined then drop the
                # (possibly stale) parameter rather than failing the update.
                # KeyboardInterrupt/SystemExit are intentionally NOT swallowed.
                if 'runtime_memory_size' in model_parameters:
                    del model_parameters['runtime_memory_size']
        else:
            model_parameters['runtime_memory_size'] = forced_runtime_memory_size
    else:
        if 'runtime_memory_size' in model_parameters:
            del model_parameters['runtime_memory_size']

    # We generate the .tflite flatbuffer twice
    # The first time, we generate with a null hash model parameter
    # and then calculate the hash of the flatbuffer (with the hash null)
    # We then update the calculated hash and re-generate the flatbuffer.

    # First pass:
    # Ensure the model 'hash' and 'date' parameters are cleared
    # This ensures the generated hash is constant
    model_parameters['hash'] = ''
    model_parameters['date'] = ''
    # Serialize the parameters and add them to the .tflite flatbuffer
    tflite_model.add_metadata(TFLITE_METADATA_TAG, model_parameters.serialize())
    # Calculate the hash of the .tflite flatbuffer (including all the metadata)
    calculated_hash = generate_hash(tflite_model.flatbuffer_data)

    # Second pass:
    # Store the real hash and date, then re-generate the metadata
    model_parameters['hash'] = forced_hash or calculated_hash
    model_parameters['date'] = forced_date or iso_time_str()
    tflite_model.add_metadata(TFLITE_METADATA_TAG, model_parameters.serialize())