Source code for mltk.core.summarize_model


import os
import io
import logging
from typing import Union



from mltk.utils.string_formatting import format_units
from mltk.utils.path import fullpath
from mltk.utils.python import append_exception_msg

from .model import (
    MltkModel,
    MltkModelEvent,
    KerasModel,
    load_mltk_model,
    load_tflite_or_keras_model
)


from .utils import ArchiveFileNotFoundError, get_mltk_logger
from .model.metrics import calculate_model_metrics
from .tflite_model import TfliteModel

def summarize_model(
    model: Union[str, MltkModel, KerasModel, TfliteModel],
    tflite: bool = False,
    build: bool = False,
    test: bool = False,
    built_model: Union[KerasModel, TfliteModel] = None
) -> str:
    """Generate a summary of the given model and return the summary as a string

    .. seealso::
       * `Model Summary Guide <https://siliconlabs.github.io/mltk/docs/guides/model_summary.html>`_
       * `Model summary API examples <https://siliconlabs.github.io/mltk/mltk/examples/summarize_model.html>`_

    Args:
        model: Either

            * a path to a `.tflite`, `.h5`, `.mltk.zip`, or `.py` file,
            * or a :py:class:`mltk.core.MltkModel`, :py:class:`mltk.core.KerasModel`,
            * or :py:class:`mltk.core.TfliteModel` instance
        tflite: If true, return the summary of the corresponding `.tflite` model.
            If true and ``model`` is a :py:class:`mltk.core.KerasModel`, it is first quantized into a `.tflite` model
        build: If true, generate a `.h5` or `.tflite` by training the given
            :py:class:`mltk.core.MltkModel` for 1 epoch. This is useful for summarizing an
            :py:class:`mltk.core.MltkModel` without fully training it first
        test: If true and ``model`` is the name of an MltkModel, load the MltkModel in testing mode
        built_model: Optional, previously built :py:class:`mltk.core.KerasModel` or
            :py:class:`mltk.core.TfliteModel` associated with the given :py:class:`mltk.core.MltkModel`

    Returns:
        A summary of the given model as a string
    """
    mltk_model = None
    mltk_model_summary = None
    tflite_size = None
    logger = get_mltk_logger()

    try:
        mltk_model, built_model = _load_or_build_model(
            model,
            tflite=tflite,
            build=build,
            built_model=built_model,
            test=test,
            logger=logger
        )
    except ArchiveFileNotFoundError as e:
        append_exception_msg(e,
            '\nAlternatively, add the --build option to summarize the model without training it first'
        )
        raise

    if isinstance(built_model, TfliteModel):
        tflite_size = built_model.flatbuffer_size
        if mltk_model is None:
            # If no MLTK model was given (because we directly loaded a .tflite file)
            # then create a default model and attempt to load the common metadata entries
            mltk_model = MltkModel()
            mltk_model.deserialize_tflite_metadata(built_model)

    if mltk_model is not None:
        mltk_model_summary = mltk_model.summary()

    model_metrics = calculate_model_metrics(built_model, logger=logger)

    summary = ''
    if isinstance(built_model, KerasModel):
        string_buffer = io.StringIO()
        def _writeln(s):
            string_buffer.write(s + '\n')
        built_model.summary(print_fn=_writeln)
        summary += string_buffer.getvalue()
    else:
        summary += built_model.summary()

    summary += '\n'
    summary += f'Total MACs: {format_units(model_metrics["total_macs"])}\n'
    summary += f'Total OPs: {format_units(model_metrics["total_ops"])}\n'

    if mltk_model_summary is not None:
        summary += f'{mltk_model_summary}\n'

    if tflite_size:
        summary += f'.tflite file size: {format_units(tflite_size, precision=1, add_space=False)}B\n'

    summary = summary.strip()

    summary_dict = dict(value=summary)
    if mltk_model is not None:
        mltk_model.trigger_event(
            MltkModelEvent.SUMMARIZE_MODEL,
            summary=summary,
            summary_dict=summary_dict,
            logger=logger
        )
        summary = summary_dict['value']

    return summary
def _load_or_build_model(
    model: Union[str, MltkModel, KerasModel, TfliteModel],
    built_model: Union[KerasModel, TfliteModel],
    test: bool,
    tflite: bool,
    build: bool,
    logger: logging.Logger
):
    """Load a previously trained .tflite/.h5 OR build the model now"""
    mltk_model = None

    # If an MltkModel instance was given
    if isinstance(model, MltkModel):
        mltk_model = model

    # Else if a KerasModel instance was given
    elif isinstance(model, KerasModel):
        built_model = model

    # Else if a TfliteModel instance was given
    elif isinstance(model, TfliteModel):
        built_model = model

    elif not isinstance(model, str):
        raise Exception('model argument must be a string or an MltkModel, KerasModel, or TfliteModel instance')

    # Else if the path to a .h5 or .tflite was given
    elif model.endswith(('.tflite', '.h5')):
        model_path = fullpath(model)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f'Model not found: {model_path}')
        built_model = load_tflite_or_keras_model(model_path)

    # Else an MLTK model name was given
    else:
        mltk_model = load_mltk_model(
            model,
            test=test,
            print_not_found_err=True
        )

    if build and mltk_model is None:
        raise Exception('Must provide an MltkModel with the build option')

    # If we have an MltkModel instance but no built model instance
    if mltk_model is not None and built_model is None:
        # If we want to build the Keras or .tflite model now
        # (i.e. if the model has not already been trained)
        if build:
            from .train_model import train_model
            from .quantize_model import quantize_model

            if tflite:
                built_model = quantize_model(
                    model=mltk_model,
                    build=True,
                    output='tflite_model'
                )
            else:
                built_mltk_model = load_mltk_model(mltk_model.model_specification_path, test=True)
                results = train_model(
                    model=built_mltk_model,
                    epochs=1,
                    quantize=False,
                    clean=None,
                    create_archive=False,
                    verbose=logger.verbose,
                )
                built_model = results.keras_model

        # Else if we need to load the .tflite from the MltkModel archive
        elif tflite:
            # If a .tflite path was given
            if isinstance(tflite, str):
                built_model = load_tflite_or_keras_model(
                    tflite,
                )
            # Else load the .tflite from the model archive
            else:
                built_model = load_tflite_or_keras_model(
                    mltk_model,
                    model_type='tflite',
                )
        # Else load the .h5 from the MltkModel archive
        else:
            built_model = load_tflite_or_keras_model(
                mltk_model,
                model_type='h5',
            )

    return mltk_model, built_model
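

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of calling
# summarize_model(), based on the docstring above. The import path follows the
# module name shown on this page (mltk.core.summarize_model); the model file
# path and model name below are hypothetical placeholders, substitute a real
# .tflite/.h5 file or the name of a registered MLTK model.
if __name__ == '__main__':
    from mltk.core import summarize_model

    # Summarize a previously trained .tflite file (hypothetical path)
    print(summarize_model('my_model.tflite'))

    # Summarize an MLTK model without fully training it first:
    # build=True trains for 1 epoch, tflite=True quantizes the result into a .tflite
    print(summarize_model('my_mltk_model', tflite=True, build=True))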