from __future__ import annotations
from typing import List, Tuple, Callable, Any, Dict, NamedTuple
import inspect
import os
import logging
import typer
from mltk.utils.path import clean_directory, create_user_dir
from mltk.core.utils import get_mltk_logger
from mltk.utils.logger import get_logger
from mltk.utils.string_formatting import format_units
from mltk.utils.python import prepend_exception_msg
from .mixins.archive_mixin import ArchiveMixin
from .mixins.tflite_model_metadata_mixin import TfliteModelMetadataMixin
from .mixins.tflite_model_parameters_mixin import TfliteModelParametersMixin
from .model_attributes import MltkModelAttributes, MltkModelAttributesDecorator
from .model_event import MltkModelEvent
@MltkModelAttributesDecorator()
class MltkModel(
    ArchiveMixin,
    TfliteModelMetadataMixin,
    TfliteModelParametersMixin
):
    """The root MLTK Model object

    This must be defined in a model specification file.

    Refer to the `Model Specification <https://siliconlabs.github.io/mltk/docs/guides/model_specification.html>`_ guide for more details.
    """

    def __init__(self, model_script_path:str=None):
        """Create the model object

        Args:
            model_script_path: Path to the model specification script.
                If omitted, the path of the script that instantiated this object
                is determined from the call stack.
        """
        self._model_script_path = model_script_path

        # If no model script path was provided
        # then attempt to automatically determine it from the
        # filepath that instantiated this MltkModel object
        # i.e. The model script path is the path of the script that created this MltkModel object
        if self._model_script_path is None:
            try:
                call_stack = inspect.stack()
                self._model_script_path = call_stack[1].filename.replace('\\', '/')
            except Exception:  # pylint: disable=broad-except
                # Best effort only: if the stack cannot be inspected the path stays None
                # and _register_attributes() falls back to a default model name
                pass

        self._cli = typer.Typer(
            context_settings=dict(
                max_content_width=100
            ),
            add_completion=False
        )
        self._attributes = MltkModelAttributes()
        self._event_handlers:Dict[MltkModelEvent, List[_EventHandlerContext]] = {}

        # At this point, no model properties have been registered
        # See load_mltk_model_with_path() for the AFTER_MODEL_LOAD event
        self.trigger_event(MltkModelEvent.BEFORE_MODEL_LOAD)

    @property
    def attributes(self) -> MltkModelAttributes:
        """Return all model attributes"""
        return self._attributes

    def get_attribute(self, name: str):
        """Return attribute value or None if unknown"""
        return self._attributes.get_value(name)

    @property
    def cli(self) -> typer.Typer:
        """Custom command CLI

        This is used to register custom commands
        The commands may be invoked with:
        mltk custom <model name> [command args]
        """
        return self._cli

    @property
    def model_specification_path(self) -> str:
        """Return the absolute path to the model's specification python script"""
        return self._attributes['model_specification_path']

    @property
    def name(self) -> str:
        """The name of this model, this is the filename of the model's Python script.
        """
        return self._attributes['name']

    @property
    def version(self) -> int:
        """The model version, e.g. 3
        """
        return self._attributes['version']

    @version.setter
    def version(self, v: int):
        if not isinstance(v, int):
            raise ValueError('Version must be an integer')
        self._attributes['version'] = v

    @property
    def description(self) -> str:
        """A description of this model and how it should be used.

        This is added to the .tflite model flatbuffer "description" field
        """
        return self._attributes['description']

    @description.setter
    def description(self, v: str):
        self._attributes['description'] = v

    @property
    def log_dir(self) -> str:
        """Path to directory where logs will be generated

        If no log directory has been configured, one is created via
        ``create_user_dir()`` using the suffix ``models/<model name>``.
        If a directory has been configured but does not exist yet,
        it is passed to ``create_user_dir()`` as the base directory.
        """
        log_dir = self._attributes['log_dir']
        if not log_dir:
            self._attributes['log_dir'] = create_user_dir(
                suffix=f'models/{self.name}',
            )
        elif not os.path.exists(log_dir):
            self._attributes['log_dir'] = create_user_dir(
                base_dir=log_dir,
            )
        return self._attributes['log_dir']

    @log_dir.setter
    def log_dir(self, v: str):
        self._attributes['log_dir'] = v

    def create_log_dir(self, suffix:str = '', delete_existing=False) -> str:
        """Create a directory for storing model log files

        Args:
            suffix: Sub-directory, relative to :py:attr:`log_dir`, to create
            delete_existing: If true then ``clean_directory()`` is invoked on the directory

        Returns:
            The path returned by ``create_user_dir()``
        """
        log_dir = create_user_dir(
            suffix=suffix,
            base_dir=self.log_dir
        )
        if delete_existing:
            clean_directory(log_dir)
        return log_dir

    def create_logger(self, name, parent: logging.Logger=None) -> logging.Logger:
        """Create a logger for this model

        The logger writes to ``<log_dir>/<name>/log.txt`` (opened in ``'w'`` mode).

        Args:
            name: Name of the logger, also used as the log sub-directory
            parent: Optional parent logger that the new logger propagates to
        """
        train_log_dir = self.create_log_dir(name)
        log_file = f'{train_log_dir}/log.txt'
        train_logger = get_logger(name, log_file=log_file, log_file_mode='w')
        if parent is not None:
            train_logger.parent = parent
            train_logger.propagate = True
        return train_logger

    @property
    def h5_log_dir_path(self) -> str:
        """Path to .h5 model file that is generated in the log directory at the end of training"""
        extension = '.test.h5' if self.test_mode_enabled else '.h5'
        return f'{self.log_dir}/{self.name}{extension}'

    @property
    def tflite_log_dir_path(self) -> str:
        """Path to .tflite model file that is generated in the log directory at the end of training (if quantization is enabled)"""
        extension = '.test.tflite' if self.test_mode_enabled else '.tflite'
        return f'{self.log_dir}/{self.name}{extension}'

    @property
    def unquantized_tflite_log_dir_path(self) -> str:
        """Path to unquantized/float32 .tflite model file that is generated in the
        log directory at the end of training (if enabled)"""
        extension = '.test.tflite' if self.test_mode_enabled else '.tflite'
        return f'{self.log_dir}/{self.name}.float32{extension}'

    @property
    def classes(self) -> List[str]:
        """Return a list of the class name strings this model expects

        Raises:
            ValueError: If no ``classes`` attribute has been registered
        """
        try:
            return self._attributes.get_value('*classes')
        except AttributeError as err:
            raise ValueError(
                'Model does not specify the dataset\'s classes.\n'
                'It must either be manually specified, e.g. my_model.classes = ["dog", "cat"] or inherit a mixin that supports an classes, e.g.: ImageDatasetMixin'
            ) from err

    @classes.setter
    def classes(self, v: List[str]):
        try:
            self._attributes.set_value('*classes', v)
        except AttributeError:
            # No matching attribute is registered yet, so register a default one
            self._attributes.register('dataset.classes', dtype=(list,tuple))
            self._attributes.set_value('dataset.classes', v)

    @property
    def n_classes(self) -> int:
        """Return the number of classes this model expects"""
        return len(self.classes)

    @property
    def input_shape(self) -> Tuple[int]:
        """Return the image input shape as a tuple of integers

        Raises:
            ValueError: If no ``input_shape`` attribute has been registered
        """
        try:
            return self._attributes.get_value('*input_shape')
        except AttributeError as err:
            raise ValueError(
                'Model does not specify the dataset\'s input_shape.\n'
                'It must either be manually specified, e.g. my_model.input_shape = (96,96,3) or inherit a mixin that supports an input_shape, e.g.: ImageDatasetMixin'
            ) from err

    @input_shape.setter
    def input_shape(self, v: Tuple[int]):
        try:
            self._attributes.set_value('*input_shape', v)
        except AttributeError:
            # No matching attribute is registered yet, so register a default one
            self._attributes.register('dataset.input_shape', dtype=(list,tuple))
            self._attributes.set_value('dataset.input_shape', v)

    @property
    def keras_custom_objects(self) -> dict:
        """Get/set custom objects that should be loaded with the Keras model

        See https://keras.io/guides/serialization_and_saving/#custom-objects for more details.
        """
        return self._attributes.get_value('keras_custom_objects', default={})

    @keras_custom_objects.setter
    def keras_custom_objects(self, v: dict):
        self._attributes['keras_custom_objects'] = v

    @property
    def test_mode_enabled(self) -> bool:
        """Return if testing mode has been enabled"""
        return self._attributes['test_mode_enabled']

    def enable_test_mode(self):
        """Enable testing mode

        This sets the ``test_mode_enabled`` attribute and
        appends ``-test`` to the current :py:attr:`log_dir`.
        """
        self._attributes['test_mode_enabled'] = True
        get_mltk_logger().info('Enabling test mode')
        self.log_dir = f'{self.log_dir}-test'

    def summary(self) -> str:
        """Return a summary of the model

        The summary contains the name, version and description followed by,
        when available, the classes, input shape, dataset and runtime memory size,
        and finally the remaining model parameters.
        """
        lines = [
            f'Name: {self.name}',
            f'Version: {self.version}',
            f'Description: {self.description}',
        ]
        params = self.model_parameters
        exclude_params = ('name', 'version', 'classes', 'runtime_memory_size')

        # Prefer the model's classes, fall back to the ones stored in the model parameters
        try:
            classes = self.classes
        except Exception:  # pylint: disable=broad-except
            classes = params['classes'] if 'classes' in params else None
        if classes:
            lines.append(f'Classes: {", ".join(str(c) for c in classes)}')

        try:
            # The shape entries are integers and must be converted before joining,
            # otherwise str.join() raises a TypeError and the shape is never reported
            input_shape = 'x'.join(str(dim) for dim in self.input_shape)
            lines.append(f'Input shape: {input_shape}')
        except Exception:  # pylint: disable=broad-except
            # The input shape is optional in the summary
            pass

        try:
            dataset = self.dataset # pylint: disable=no-member
            if isinstance(dataset, str):
                lines.append(f'Dataset: {dataset}')
        except Exception:  # pylint: disable=broad-except
            # The dataset is optional in the summary
            pass

        if 'runtime_memory_size' in params and params['runtime_memory_size']:
            lines.append(f'Runtime memory size (RAM): {format_units(params["runtime_memory_size"])}')

        for key, value in params.items():
            # Skip unset hash/date entries and anything already reported above
            if (key in ('hash', 'date') and not value) or key in exclude_params:
                continue
            lines.append(f'{key}: {value}')

        return '\n'.join(lines).strip()

    def add_event_handler(
        self,
        event:MltkModelEvent,
        handler:Callable[[MltkModel, logging.Logger, Any],None],
        _raise_exception=False,
        **kwargs
    ):
        """Register an event handler

        Register a handler that will be invoked on the corresponding :py:class:`~MltkModelEvent`

        The given handler should have the signature:

        .. highlight:: python
        .. code-block:: python

            def my_event_handler(mltk_model:MltkModel, logger:logging.Logger, **kwargs):
                ...

        Where ``kwargs`` will contain keyword arguments specific to the :py:class:`~MltkModelEvent`
        as well as any keyword arguments passed to this API.

        .. note::
            - By default, exceptions raised by the handler will be logged, but not stop further execution
              Use ``_raise_exception`` to cause the handler to raise the exception.
            - Event handlers are invoked in the order in which they are registered

        Args:
            event: The :py:class:`~MltkModelEvent` on which the handler will be invoked
            handler: Function to be invoked for the given event
            _raise_exception: By default, exceptions are only logged.
                This allows for raising the handler exception.
            kwargs: Additional keyword arguments to provided to the ``handler``
                NOTE: These keyword arguments must not collide with the event-specific keyword arguments
        """
        if event not in self._event_handlers:
            self._event_handlers[event] = []
        self._event_handlers[event].append(
            _EventHandlerContext(
                handler=handler,
                raise_exception=_raise_exception,
                kwargs=kwargs
            )
        )

    def trigger_event(
        self,
        event:MltkModelEvent,
        **kwargs
    ):
        """Trigger all handlers for the given event

        This is used internally by the MLTK.

        Args:
            event: The :py:class:`~MltkModelEvent` to trigger
            kwargs: Event-specific keyword arguments forwarded to each handler.
                An optional ``logger`` entry is removed from the keyword arguments
                and passed to the handlers as the ``logger`` argument
                (the MLTK logger is used if it is omitted).
        """
        # Imported locally, presumably to avoid a circular import -- TODO confirm
        from .model_utils import push_active_model, pop_active_model

        get_mltk_logger().debug(f'Model event: {event}')

        # The *_STARTUP / *_SHUTDOWN events bracket the period in which this model is "active"
        if event in (
            MltkModelEvent.TRAIN_STARTUP,
            MltkModelEvent.EVALUATE_STARTUP,
            MltkModelEvent.QUANTIZE_STARTUP,
        ):
            push_active_model(self)
        elif event in (
            MltkModelEvent.TRAIN_SHUTDOWN,
            MltkModelEvent.EVALUATE_SHUTDOWN,
            MltkModelEvent.QUANTIZE_SHUTDOWN
        ):
            pop_active_model()

        if event not in self._event_handlers:
            return

        logger = kwargs.pop('logger', get_mltk_logger())
        for context in self._event_handlers[event]:
            try:
                context.handler(
                    mltk_model=self,
                    logger=logger,
                    **kwargs,
                    **context.kwargs
                )
            except Exception as e:
                if not context.raise_exception:
                    # Log the failure and continue with the remaining handlers
                    get_mltk_logger().warning(f'Model event: {event}, handler: {context.handler}, failed, err: {e}', exc_info=e)
                else:
                    prepend_exception_msg(e, f'Model event: {event}, handler: {context.handler}, failed')
                    raise

    def __setattr__(self, name, value):
        # Only private names, registered attributes and class properties may be assigned.
        # This catches misspelled model settings in a model specification.
        if not name.startswith('_') and not self.has_attribute(name):
            raise AttributeError(f'MltkModel does not have the attribute: {name}')
        object.__setattr__(self, name, value)

    def has_attribute(self, name):
        """Return if the given name is a registered model attribute or a property of this class"""
        if name in self._attributes:
            return True
        # Look the name up directly on the class instead of enumerating every class member.
        # This is invoked for every public attribute assignment via __setattr__()
        return isinstance(getattr(self.__class__, name, None), property)

    def __str__(self):
        s = f'Name: {self.name}\n'
        s += f'Version: {self.version}\n'
        s += f'Description: {self.description}'
        return s

    def _register_attributes(self):
        """Register the base model attributes and their default values"""
        if self._model_script_path:
            # By default, the model name is the model file's filename (without the extension)
            model_name = os.path.basename(self._model_script_path)
            idx = model_name.rfind('.')
            if idx != -1:
                model_name = model_name[:idx]
        else:
            model_name = 'my_model'

        self._attributes.register('model_specification_path', self._model_script_path, dtype=str)
        self._attributes.register('description', 'Generated by Silicon Lab\'s MLTK Python package', dtype=str)
        self._attributes.register('log_dir', '', dtype=str)
        self._attributes.register('name', model_name, dtype=str)
        self._attributes.register('version', 1, dtype=int)
        self._attributes.register('test_mode_enabled', False, dtype=bool)
        self._attributes.register('keras_custom_objects', dtype=dict)
class _EventHandlerContext(NamedTuple):
    """Registration record for a single model event handler

    Instances are created by ``MltkModel.add_event_handler()`` and
    consumed by ``MltkModel.trigger_event()``.
    """
    # Callback invoked as handler(mltk_model=..., logger=..., **kwargs)
    handler: Callable[['MltkModel', logging.Logger, Any], None]
    # Additional keyword arguments, given at registration, that are forwarded to the handler
    kwargs: Dict[str, Any]
    # If true, an exception raised by the handler is re-raised instead of only being logged
    raise_exception: bool