def update_model_parameters(
    model: Union[MltkModel, TfliteModel, str],
    params: dict = None,
    description: str = None,
    output: str = None,
    accelerator: str = None,
) -> Union[str, TfliteModel]:
    """Update the parameters of a previously trained model

    This updates the metadata of a previously trained `.tflite` model.
    The parameters are taken from either the given :py:class:`mltk.core.MltkModel`'s python script
    or the given "params" dictionary and added to the `.tflite` model file.

    .. note:: The `.tflite` metadata is only modified.
        The weights and model structure of the `.tflite` file are NOT modified.

    .. seealso::
       * `Model Parameters Guide <https://siliconlabs.github.io/mltk/docs/guides/model_parameters.html>`_
       * `Update model parameters API examples <https://siliconlabs.github.io/mltk/mltk/examples/update_params.html>`_

    Args:
        model: Either the name of a model
            a :py:class:`mltk.core.MltkModel` or :py:class:`mltk.core.TfliteModel` instance
            or the path to a `.tflite` model file or `.mltk.zip` model archive
        params: Optional dictionary of parameters to add `.tflite`. If omitted then model argument must be
            a :py:class:`mltk.core.MltkModel` instance or model name
        description: Optional description to add to `.tflite`
        output: Optional, directory path or file path to generated `.tflite` file.
            If none then generate in model log directory.
            If output='tflite_model', then return the :py:class:`mltk.core.TfliteModel` object instead of `.tflite` file path
        accelerator: Optional hardware accelerator to use when determining the ``runtime_memory_size`` parameter.
            If None then default to the CMSIS kernels for calculating the required tensor arena size.

    Returns:
        The file path to the generated `.tflite` OR TfliteModel object if output=`tflite_model`

    Raises:
        Exception: If ``params`` is given but is not a dict, or if ``model`` is a
            `.tflite`/`.mltk.zip` path (or TfliteModel instance) and no ``params`` were given
        RuntimeError: If ``model`` is not one of the supported types, or no output
            path could be determined
    """
    update_archive = False
    mltk_model = None
    tflite_model = None
    tflite_path = None
    logger = get_mltk_logger()

    if params and not isinstance(params, dict):
        raise Exception('params argument must be a dictionary')

    # ------------------------------------------------------------------
    # Case 1: a .tflite path or a TfliteModel instance was given.
    # There is no MltkModel script to pull parameters from, so "params" is required.
    if isinstance(model, TfliteModel) or (isinstance(model, str) and model.endswith('.tflite')):
        if params is None:
            raise Exception('Must provide "param" argument if "model" is a .tflite path or TfliteModel instance')

        if isinstance(model, str):
            tflite_path = fullpath(model)
            tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)
        else:
            tflite_model = model
            tflite_path = tflite_model.path

        # See if the .tflite already has parameters;
        # if so merge the given params into them, otherwise start fresh
        params_fb = tflite_model.get_metadata(TFLITE_METADATA_TAG)
        if params_fb is not None:
            model_parameters = TfliteModelParameters.deserialize(params_fb)
            if params:
                model_parameters.update(params)
        else:
            model_parameters = TfliteModelParameters(params)

    # ------------------------------------------------------------------
    # Case 2: an MltkModel instance, a model name, or a .mltk.zip archive path was given.
    # The parameters come from the model's python script (plus any "params" given).
    elif isinstance(model, MltkModel) or isinstance(model, str):
        if isinstance(model, MltkModel):
            mltk_model = model
            if not output:
                update_archive = mltk_model.check_archive_file_is_writable()

        elif model.endswith('.mltk.zip'):
            if params is None:
                raise Exception('Must provide "param" argument if "model" is a .mltk.zip path')
            archive_path = model
            mltk_model = load_mltk_model_with_path(model, logger=logger)
            # Typically, the mltk_model.tflite_archive_path is determined from the path of the mltk_model.model_specification_path
            # However, in this special case, the model archive path is given
            # So we need to add a bit of hackery to ensure the mltk_model.tflite_archive_path points to the correct location
            # by override the mltk_model.model_specification_path
            mltk_model._attributes._entries['model_specification_path'].value = \
                archive_path.replace('.mltk.zip', '.py')  # pylint: disable=protected-access
            mltk_model.check_archive_file_is_writable(throw_exception=True)
            update_archive = True

        else:
            mltk_model = load_mltk_model(model)
            if not output:
                update_archive = mltk_model.check_archive_file_is_writable()

        tflite_path = mltk_model.tflite_archive_path
        tflite_model = TfliteModel.load_flatbuffer_file(tflite_path)
        # Prefer the caller-supplied description; only fall back to the
        # model script's description when none was explicitly given.
        # (Previously the model's description unconditionally clobbered the
        # "description" argument, contradicting this API's documented contract.)
        if not description:
            description = mltk_model.description
        mltk_model.populate_model_parameters()
        model_parameters = mltk_model.model_parameters
        # If additional parameters were given to this API
        # then add them to the model now
        if params:
            model_parameters.update(params)

    else:
        raise RuntimeError('model argument must be model name, MltkModel or TfliteModel instance, or path to .tflite or .mltk.zip')

    # ------------------------------------------------------------------
    # Determine what the return value should be:
    #
    # If an output argument was supplied
    if output:
        # If we should just return the TfliteModel instance
        if output == 'tflite_model':
            retval = 'tflite_model'
        # If the path to a .tflite was given then just use that
        elif output.endswith('.tflite'):
            retval = output
        # Else we were given a directory
        # If the input was an MltkModel instance or name
        elif mltk_model is not None:
            # The retval is the given output directory
            # plus the model name and .tflite extension
            if mltk_model.test_mode_enabled:
                retval = f'{output}/{mltk_model.name}.test.tflite'
            else:
                retval = f'{output}/{mltk_model.name}.tflite'
        # Otherwise the retval is the given output directory
        # plus input .tflite filename
        else:
            retval_name = 'my_model.tflite' if tflite_path is None else os.path.basename(tflite_path)
            retval = f'{output}/{retval_name}'

    # Otherwise no output was given
    # If an MltkModel instance or name was given
    elif mltk_model is not None:
        # Then update the model file in its log directory
        if mltk_model.test_mode_enabled:
            retval = f'{mltk_model.create_log_dir()}/{mltk_model.name}.test.tflite'
        else:
            retval = f'{mltk_model.create_log_dir()}/{mltk_model.name}.tflite'

    # Else we just update the input .tflite file
    elif tflite_path is not None:
        retval = tflite_path
    else:
        raise RuntimeError('Failed to determine output path')

    # Add the model description if necessary
    if description:
        tflite_model.description = description

    # Add the default parameters to the model's metadata
    add_default_parameters(
        tflite_model,
        model_parameters,
        forced_params=params,
        accelerator=accelerator
    )

    if retval == 'tflite_model':
        retval = tflite_model
    else:
        tflite_model.save(retval)

    if update_archive:
        logger.info(f'Updating {mltk_model.archive_path}')
        mltk_model.add_archive_file('__mltk_model_spec__')
        # NOTE(review): if output='tflite_model' was given with a .mltk.zip
        # model, retval is a TfliteModel object here, not a path — confirm
        # add_archive_file() accepts that, or that the combination is unsupported.
        mltk_model.add_archive_file(retval)

    return retval
def add_default_parameters(
    tflite_model: TfliteModel,
    model_parameters: TfliteModelParameters,
    forced_params: dict = None,
    accelerator: str = None,
    add_runtime_memory_size: bool = True
):
    """Add the default parameters to the model's metadata

    Serializes ``model_parameters`` (plus the default ``runtime_memory_size``,
    ``hash`` and ``date`` entries) into the given ``tflite_model``'s metadata.

    Args:
        tflite_model: The model whose metadata is updated in-place
        model_parameters: The parameters to serialize into the metadata
            (deep-copied; the caller's instance is not modified)
        forced_params: Optional dict whose ``runtime_memory_size``, ``hash`` and
            ``date`` entries (if present) override the computed defaults
        accelerator: Optional hardware accelerator used when computing
            ``runtime_memory_size`` via TfliteMicro
        add_runtime_memory_size: If False, remove any ``runtime_memory_size``
            parameter instead of computing one
    """
    model_parameters = copy.deepcopy(model_parameters)
    forced_params = forced_params or {}
    forced_runtime_memory_size = forced_params.get('runtime_memory_size', None)
    forced_date = forced_params.get('date', None)
    forced_hash = forced_params.get('hash', None)

    # We generate the .tflite flatbuffer twice
    # The first time, we generate with a null hash model parameter
    # and then calculate the hash of the flatbuffer (with the hash null)
    # We then update the calculated hash and re-generate the flatbuffer.
    calculated_hash = None
    for i in range(2):
        if i == 0:
            # Try to determine the RAM required by TFLM
            # and add it to the parameters
            if add_runtime_memory_size:
                if forced_runtime_memory_size is None:
                    try:
                        from mltk.core.tflite_micro import TfliteMicro
                        tflm_model = TfliteMicro.load_tflite_model(
                            tflite_model,
                            accelerator=accelerator,
                            runtime_buffer_size=-1  # Set the runtime buffer size to -1 so the optimal size is automatically found
                        )
                        try:
                            model_parameters['runtime_memory_size'] = tflm_model.details.runtime_memory_size
                        finally:
                            TfliteMicro.unload_model(tflm_model)
                    except Exception:
                        # Best-effort: if the size could not be determined, drop any
                        # stale value rather than failing the whole update.
                        # (Was a bare "except:", which also swallowed
                        # KeyboardInterrupt/SystemExit.)
                        if 'runtime_memory_size' in model_parameters:
                            del model_parameters['runtime_memory_size']
                else:
                    model_parameters['runtime_memory_size'] = forced_runtime_memory_size
            else:
                if 'runtime_memory_size' in model_parameters:
                    del model_parameters['runtime_memory_size']

            # Ensure the model 'hash' and 'date' parameters are cleared
            # on the first pass
            # This ensures the generated hash is constant
            model_parameters['hash'] = ''
            model_parameters['date'] = ''
        else:
            model_parameters['hash'] = forced_hash or calculated_hash
            model_parameters['date'] = forced_date or iso_time_str()

        # Serialize the tflite_model_parameters into a flatbuffer
        serialized_model_parameters = model_parameters.serialize()

        # Add the params to the .tflite flatbuffer
        tflite_model.add_metadata(TFLITE_METADATA_TAG, serialized_model_parameters)

        # If this is the first pass, then calculate the hash
        # of the .tflite flatbuffer (including all the metadata)
        if i == 0:
            calculated_hash = generate_hash(tflite_model.flatbuffer_data)
Important: We use cookies only for functional and traffic analytics.
We DO NOT use cookies for any marketing purposes. By using our site you acknowledge you have read and understood our Cookie Policy.