@MltkModelAttributesDecorator()
class MltkModel(
    ArchiveMixin,
    TfliteModelMetadataMixin,
    TfliteModelParametersMixin
):
    """The root MLTK Model object

    This must be defined in a model specification file.

    Refer to the `Model Specification <https://siliconlabs.github.io/mltk/docs/guides/model_specification.html>`_
    guide for more details.
    """

    def __init__(self, model_script_path: str = None):
        """Create the model object

        Args:
            model_script_path: Path to the model specification script.
                If omitted, the filename of the caller that instantiated this object is used.
        """
        self._model_script_path = model_script_path

        # If no model script path was provided
        # then attempt to automatically determine it from the
        # filepath that instantiated this MltkModel object
        # i.e The model script path is the path of the script that created this MltkModel object
        if self._model_script_path is None:
            try:
                # context=0: only the frame filename is needed, so skip loading source lines
                call_stack = inspect.stack(0)
                self._model_script_path = call_stack[1].filename.replace('\\', '/')
            except Exception:
                # Best-effort only: when no path is known, _register_attributes()
                # falls back to the model name 'my_model'
                pass

        self._cli = typer.Typer(
            context_settings=dict(max_content_width=100),
            add_completion=False
        )
        self._attributes = MltkModelAttributes()
        self._event_handlers: Dict[MltkModelEvent, List[_EventHandlerContext]] = {}

        # At this point, no model properties have been registered
        # See load_mltk_model_with_path() for the AFTER_MODEL_LOAD event
        self.trigger_event(MltkModelEvent.BEFORE_MODEL_LOAD)

    @property
    def attributes(self) -> MltkModelAttributes:
        """Return all model attributes"""
        return self._attributes

    def get_attribute(self, name: str):
        """Return attribute value or None if unknown

        NOTE(review): the ``classes``/``input_shape`` properties catch ``AttributeError``
        from ``get_value()`` for wildcard names; confirm whether unknown names
        return None or raise.
        """
        return self._attributes.get_value(name)

    @property
    def cli(self) -> typer.Typer:
        """Custom command CLI

        This is used to register custom commands
        The commands may be invoked with:
        mltk custom <model name> [command args]
        """
        return self._cli

    @property
    def model_specification_path(self) -> str:
        """Return the absolute path to the model's specification python script"""
        return self._attributes['model_specification_path']

    @property
    def name(self) -> str:
        """The name of this model, this is the filename of the model's Python script."""
        return self._attributes['name']

    @property
    def version(self) -> int:
        """The model version, e.g. 3"""
        return self._attributes['version']

    @version.setter
    def version(self, v: int):
        if not isinstance(v, int):
            raise ValueError('Version must be an integer')
        self._attributes['version'] = v

    @property
    def description(self) -> str:
        """A description of this model and how it should be used.
        This is added to the .tflite model flatbuffer "description" field
        """
        return self._attributes['description']

    @description.setter
    def description(self, v: str):
        self._attributes['description'] = v

    @property
    def log_dir(self) -> str:
        """Path to directory where logs will be generated

        If no directory is configured, one is obtained from ``create_user_dir()``
        with the suffix ``models/<model name>``; if the configured directory does
        not exist on disk, it is passed to ``create_user_dir()`` as the base directory.
        The resolved path is stored back into the attributes.
        """
        log_dir = self._attributes['log_dir']
        if not log_dir:
            self._attributes['log_dir'] = create_user_dir(
                suffix=f'models/{self.name}',
            )
        elif not os.path.exists(log_dir):
            self._attributes['log_dir'] = create_user_dir(
                base_dir=log_dir,
            )
        return self._attributes['log_dir']

    @log_dir.setter
    def log_dir(self, v: str):
        self._attributes['log_dir'] = v

    def create_log_dir(self, suffix: str = '', delete_existing=False) -> str:
        """Create a directory for storing model log files

        Args:
            suffix: Sub-directory path relative to :py:attr:`log_dir`
            delete_existing: If True, clean the directory's existing contents

        Returns:
            Path to the log directory
        """
        log_dir = create_user_dir(suffix=suffix, base_dir=self.log_dir)
        if delete_existing:
            clean_directory(log_dir)
        return log_dir

    def create_logger(self, name, parent: logging.Logger = None) -> logging.Logger:
        """Create a logger for this model

        The logger writes (mode 'w') to ``<log_dir>/<name>/log.txt``.
        If ``parent`` is given, records also propagate to it.
        """
        train_log_dir = self.create_log_dir(name)
        log_file = f'{train_log_dir}/log.txt'
        train_logger = get_logger(name, log_file=log_file, log_file_mode='w')
        if parent is not None:
            train_logger.parent = parent
            train_logger.propagate = True
        return train_logger

    @property
    def h5_log_dir_path(self) -> str:
        """Path to .h5 model file that is generated in the log directory at the end of training"""
        h5_path = f'{self.log_dir}/{self.name}'
        if self.test_mode_enabled:
            h5_path += '.test.h5'
        else:
            h5_path += '.h5'
        return h5_path

    @property
    def tflite_log_dir_path(self) -> str:
        """Path to .tflite model file that is generated in the log directory at the end of training (if quantization is enabled)"""
        tflite_path = f'{self.log_dir}/{self.name}'
        if self.test_mode_enabled:
            tflite_path += '.test.tflite'
        else:
            tflite_path += '.tflite'
        return tflite_path

    @property
    def unquantized_tflite_log_dir_path(self) -> str:
        """Path to unquantized/float32 .tflite model file that is generated in the log directory at the end of training (if enabled)"""
        tflite_path = f'{self.log_dir}/{self.name}.float32'
        if self.test_mode_enabled:
            tflite_path += '.test.tflite'
        else:
            tflite_path += '.tflite'
        return tflite_path

    @property
    def classes(self) -> List[str]:
        """Return a list of the class name strings this model expects

        Raises:
            ValueError: If no attribute matching ``*classes`` is available
        """
        try:
            return self._attributes.get_value('*classes')
        except AttributeError: # pylint: disable=raise-missing-from
            raise ValueError(
                'Model does not specify the dataset\'s classes.\n'
                'It must either be manually specified, e.g. my_model.classes = ["dog", "cat"] or inherit a mixin that supports an classes, e.g.: ImageDatasetMixin'
            )

    @classes.setter
    def classes(self, v: List[str]):
        try:
            self._attributes.set_value('*classes', v)
        except AttributeError:
            # No mixin provides a classes attribute, so register a generic one
            self._attributes.register('dataset.classes', dtype=(list, tuple))
            self._attributes.set_value('dataset.classes', v)

    @property
    def n_classes(self) -> int:
        """Return the number of classes this model expects"""
        return len(self.classes)

    @property
    def input_shape(self) -> Tuple[int]:
        """Return the image input shape as a tuple of integers

        Raises:
            ValueError: If no attribute matching ``*input_shape`` is available
        """
        try:
            return self._attributes.get_value('*input_shape')
        except AttributeError: # pylint: disable=raise-missing-from
            raise ValueError(
                'Model does not specify the dataset\'s input_shape.\n'
                'It must either be manually specified, e.g. my_model.input_shape = (96,96,3) or inherit a mixin that supports an input_shape, e.g.: ImageDatasetMixin'
            )

    @input_shape.setter
    def input_shape(self, v: Tuple[int]):
        try:
            self._attributes.set_value('*input_shape', v)
        except AttributeError:
            # No mixin provides an input_shape attribute, so register a generic one
            self._attributes.register('dataset.input_shape', dtype=(list, tuple))
            self._attributes.set_value('dataset.input_shape', v)

    @property
    def keras_custom_objects(self) -> dict:
        """Get/set custom objects that should be loaded with the Keras model

        See https://keras.io/guides/serialization_and_saving/#custom-objects for more details.
        """
        return self._attributes.get_value('keras_custom_objects', default={})

    @keras_custom_objects.setter
    def keras_custom_objects(self, v: dict):
        self._attributes['keras_custom_objects'] = v

    @property
    def test_mode_enabled(self) -> bool:
        """Return if testing mode has been enabled"""
        return self._attributes['test_mode_enabled']

    def enable_test_mode(self):
        """Enable testing mode

        Sets the test-mode flag and appends ``-test`` to the current :py:attr:`log_dir`.
        """
        self._attributes['test_mode_enabled'] = True
        get_mltk_logger().info('Enabling test mode')
        self.log_dir = f'{self.log_dir}-test'

    def summary(self) -> str:
        """Return a summary of the model

        Includes name, version, description, and (when available) the classes,
        input shape, dataset name, runtime memory size and the remaining model parameters.
        """
        s = f'Name: {self.name}\n'
        s += f'Version: {self.version}\n'
        s += f'Description: {self.description}\n'

        params = self.model_parameters
        exclude_params = ['name', 'version', 'classes', 'runtime_memory_size']

        try:
            classes = self.classes
        except Exception:
            # Fall back to the classes stored in the model parameters (if any)
            if 'classes' in params:
                classes = params['classes']
            else:
                classes = None
        if classes:
            classes = ', '.join(classes)
            s += f'Classes: {classes}\n'

        try:
            # input_shape holds integers, so each dimension must be converted before joining
            input_shape = 'x'.join(str(dim) for dim in self.input_shape)
            s += f'Input shape: {input_shape}\n'
        except Exception:
            # The model may not define an input shape
            pass

        try:
            dataset = self.dataset # pylint: disable=no-member
            if isinstance(dataset, str):
                s += f'Dataset: {dataset}\n'
        except Exception:
            # The model may not define a dataset attribute
            pass

        if 'runtime_memory_size' in params and params['runtime_memory_size']:
            s += f'Runtime memory size (RAM): {format_units(params["runtime_memory_size"])}\n'

        for key, value in params.items():
            # Skip parameters already reported above, and empty hash/date entries
            if (key in ('hash', 'date') and not value) or key in exclude_params:
                continue
            s += f'{key}: {value}\n'

        return s.strip()

    def add_event_handler(
        self,
        event: MltkModelEvent,
        handler: Callable[[MltkModel, logging.Logger, Any], None],
        _raise_exception=False,
        **kwargs
    ):
        """Register an event handler

        Register a handler that will be invoked on the corresponding :py:class:`~MltkModelEvent`

        The given handler should have the signature:

        .. highlight:: python
        .. code-block:: python

            def my_event_handler(mltk_model:MltkModel, logger:logging.Logger, **kwargs):
                ...

        Where ``kwargs`` will contain keyword arguments specific to the :py:class:`~MltkModelEvent`
        as well as any keyword arguments passed to this API.

        .. note::

            - By default, exceptions raised by the handler will be logged, but not stop further execution.
              Use ``_raise_exception`` to cause the handler to raise the exception.
            - Event handlers are invoked in the order in which they are registered

        Args:
            event: The :py:class:`~MltkModelEvent` on which the handler will be invoked
            handler: Function to be invoked for the given event
            _raise_exception: By default, exceptions are only logged. This allows for raising the handler exception.
            kwargs: Additional keyword arguments to provide to the ``handler``
                NOTE: These keyword arguments must not collide with the event-specific keyword arguments,
                nor with ``mltk_model`` or ``logger`` which are always passed to the handler
        """
        self._event_handlers.setdefault(event, []).append(
            _EventHandlerContext(
                handler=handler,
                raise_exception=_raise_exception,
                kwargs=kwargs
            )
        )

    def trigger_event(self, event: MltkModelEvent, **kwargs):
        """Trigger all handlers for the given event

        This is used internally by the MLTK.

        STARTUP events push this model as the active model and SHUTDOWN events pop it,
        before any registered handlers run. An optional ``logger`` keyword is removed
        from ``kwargs`` and passed to each handler (defaults to the MLTK logger).
        """
        from .model_utils import push_active_model, pop_active_model

        get_mltk_logger().debug(f'Model event: {event}')

        if event in (
            MltkModelEvent.TRAIN_STARTUP,
            MltkModelEvent.EVALUATE_STARTUP,
            MltkModelEvent.QUANTIZE_STARTUP,
        ):
            push_active_model(self)
        elif event in (
            MltkModelEvent.TRAIN_SHUTDOWN,
            MltkModelEvent.EVALUATE_SHUTDOWN,
            MltkModelEvent.QUANTIZE_SHUTDOWN
        ):
            pop_active_model()

        if event not in self._event_handlers:
            return

        logger = kwargs.pop('logger', get_mltk_logger())
        for context in self._event_handlers[event]:
            try:
                context.handler(
                    mltk_model=self,
                    logger=logger,
                    **kwargs,
                    **context.kwargs
                )
            except Exception as e:
                if not context.raise_exception:
                    get_mltk_logger().warning(f'Model event: {event}, handler: {context.handler}, failed, err: {e}', exc_info=e)
                else:
                    prepend_exception_msg(e, f'Model event: {event}, handler: {context.handler}, failed')
                    raise

    def __setattr__(self, name, value):
        # Public (non-underscore) names must be a registered model attribute or a
        # class property; presumably this turns typos in model specs into errors
        if not name.startswith('_') and not self.has_attribute(name):
            raise AttributeError(f'MltkModel does not have the attribute: {name}')
        object.__setattr__(self, name, value)

    def has_attribute(self, name) -> bool:
        """Return True if ``name`` is a registered model attribute or a property of this model's class"""
        if name in self._attributes:
            return True
        # Direct lookup (including inherited members) instead of scanning every class member
        return isinstance(getattr(type(self), name, None), property)

    def __str__(self):
        s = f'Name: {self.name}\n'
        s += f'Version: {self.version}\n'
        s += f'Description: {self.description}'
        return s

    def _register_attributes(self):
        """Register the core model attributes with their default values"""
        if self._model_script_path:
            # By default, the model name is the model file's filename
            # (with the final extension removed)
            model_name = os.path.basename(self._model_script_path)
            idx = model_name.rfind('.')
            if idx != -1:
                model_name = model_name[:idx]
        else:
            model_name = 'my_model'

        self._attributes.register('model_specification_path', self._model_script_path, dtype=str)
        self._attributes.register('description', 'Generated by Silicon Lab\'s MLTK Python package', dtype=str)
        self._attributes.register('log_dir', '', dtype=str)
        self._attributes.register('name', model_name, dtype=str)
        self._attributes.register('version', 1, dtype=int)
        self._attributes.register('test_mode_enabled', False, dtype=bool)
        self._attributes.register('keras_custom_objects', dtype=dict)
Important: We use cookies only for functional and traffic analytics.
We DO NOT use cookies for any marketing purposes. By using our site you acknowledge you have read and understood our Cookie Policy.