Source code for mltk.core.model.mixins.weights_and_biases_mixin

from typing import Optional, Dict, Any
import os
import shutil
import copy
import logging
import warnings



try:
    import wandb
    from wandb.keras import WandbCallback
    from wandb.keras import WandbModelCheckpoint
except ImportError:
    class WandbCallback:
        '''Placeholder for https://docs.wandb.ai/ref/python/integrations/keras/wandbcallback'''
    class WandbModelCheckpoint:
        '''Placeholder for https://docs.wandb.ai/ref/python/integrations/keras/wandbmodelcheckpoint'''


from mltk.utils.python import prepend_exception_msg
from mltk.utils.string_formatting import iso_time_filename_str
from mltk.core.utils import get_mltk_logger
from mltk.core.profiling_results import ProfilingModelResults
from mltk.core.evaluation_results import EvaluationResults

from .base_mixin import BaseMixin
from ..model_attributes import MltkModelAttributesDecorator
from ..model_event import MltkModelEvent


@MltkModelAttributesDecorator()
class WeightsAndBiasesMixin(BaseMixin):
    """Provides various properties to the base :py:class:`~MltkModel`
    used by the `Weights & Biases <https://wandb.ai>`_ 3rd-party cloud backend.

    .. seealso::

       - `Tutorial: Cloud logging with Weights & Biases <https://siliconlabs.github.io/mltk/mltk/tutorials/cloud_logging_with_wandb.html>`_
       - `Weights & Biases Documentation <https://docs.wandb.ai>`_
    """

    @property
    def wandb_is_initialized(self) -> bool:
        """Return True if the wandb backend is initialized"""
        return self._attributes.get_value('wandb.is_initialized', default=False)

    @property
    def wandb_is_disabled(self) -> bool:
        """Manually disable the wandb backend"""
        return self._attributes.get_value('wandb.is_disabled', default=False)

    @wandb_is_disabled.setter
    def wandb_is_disabled(self, v: bool):
        self._attributes['wandb.is_disabled'] = v

    @property
    def wandb_init_kwargs(self) -> dict:
        """Additional arguments to provide to ``wandb.init()``

        The following arguments are automatically populated by this mixin:

        - **project** - The :py:attr:`~MltkModel.name` of the model
        - **job_type** - ``train``, ``evaluate``, ``quantize``, or ``profile``
        - **dir** - The :py:attr:`~MltkModel.log_dir`
        - **id** - The timestamp of when training was invoked,
          see :py:attr:`~WeightsAndBiasesMixin.wandb_session_id`.
          This id is re-used by the ``evaluate``, ``profile``, and ``quantize`` commands
        - **resume** - Set to ``never`` for the ``train`` command, and ``must`` otherwise

        See `wandb.init() <https://docs.wandb.ai/ref/python/init>`_ for the other available arguments
        """
        return self._attributes.get_value('wandb.init_kwargs', default={})

    @wandb_init_kwargs.setter
    def wandb_init_kwargs(self, v: dict):
        self._attributes['wandb.init_kwargs'] = v

    @property
    def wandb_config(self) -> dict:
        """Additional configuration values

        This populates the `wandb.config <https://docs.wandb.ai/guides/track/config>`_ object,
        which saves your training configuration: hyperparameters, input settings such as the
        dataset name or model type, and any other independent variables for your experiments.
        This is useful for analyzing your experiments and reproducing your work in the future.
        You'll be able to group by config values in the web interface, comparing the settings
        of different runs and seeing how they affect the output.
        """
        return self._attributes.get_value('wandb.config', default={})

    @wandb_config.setter
    def wandb_config(self, v: dict):
        self._attributes['wandb.config'] = v

    @property
    def wandb_callback(self) -> WandbCallback:
        """Keras callback to automatically log info to wandb

        This allows for specifying a custom
        `wandb.keras.WandbCallback <https://docs.wandb.ai/ref/python/integrations/keras/wandbcallback>`_.
        If not set, then this mixin automatically populates the callback.
        """
        return self._attributes.get_value('wandb.callback', default=None)

    @wandb_callback.setter
    def wandb_callback(self, v: WandbCallback):
        self._attributes['wandb.callback'] = v

    @property
    def wandb_model_checkpoint_callback(self) -> WandbModelCheckpoint:
        """Callback to periodically save the Keras model or model weights

        This allows for specifying a custom
        `wandb.keras.WandbModelCheckpoint <https://docs.wandb.ai/ref/python/integrations/keras/wandbmodelcheckpoint>`_.
""" return self._attributes.get_value('wandb.model_checkpoint_callback', default=None) @wandb_model_checkpoint_callback.setter def wandb_model_checkpoint_callback(self, v:WandbModelCheckpoint): self._attributes['wandb.model_checkpoint_callback'] = v @property def wandb_session_id(self) -> str: """The wandb project session or run ID This is the timestamp of when the last `train` command was invoked for the model. This ID is re-used for `evaluate`, `profile`, and `quantize` commands. This value is used at the ``id`` argument to `wandb.init() <https://docs.wandb.ai/ref/python/init>`_ """ return self._attributes.get_value('wandb.session_id', default=None)
    def wandb_save(
        self,
        glob_str: Optional[str] = None,
        base_path: Optional[str] = None,
        policy='live',
        logger: logging.Logger = None
    ):
        """Save files to the wandb cloud

        Internally, this invokes `wandb.save() <https://docs.wandb.ai/ref/python/save>`_

        Args:
            glob_str: A relative or absolute path to a unix glob or regular path.
                If this isn't specified, then this method is a no-op.
            base_path: The base path to run the glob relative to
            policy: One of ``live``, ``now``, or ``end``

                - ``live``: upload the file as it changes, overwriting the previous version
                - ``now``: upload the file once now
                - ``end``: only upload the file when the run ends
        """
        if not self.wandb_is_initialized:
            return
        logger = logger or get_mltk_logger()
        try:
            logger.debug(f'Saving to wandb: glob_str={glob_str}, base_path={base_path}')
            wandb.save(
                glob_str=glob_str,
                base_path=base_path,
                policy=policy
            )
        except Exception as e:
            logger.error(f'Failed to save to wandb cloud, glob_str={glob_str}, base_path={base_path}, err: {e}')
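
    # Example (hypothetical sketch): once wandb is initialized, a script could
    # upload a generated artifact once. The file name is made up for illustration:
    #
    #   my_model.wandb_save(glob_str='confusion_matrix.png', policy='now')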
    def wandb_log(
        self,
        data: Dict[str, Any],
        step: Optional[int] = None,
        commit: Optional[bool] = True,
        logger: logging.Logger = None
    ):
        """Log a dictionary of data to the current wandb run's history

        Internally, this invokes `wandb.log() <https://docs.wandb.ai/ref/python/log>`_

        Args:
            data: A dict of serializable python objects, i.e. str, ints, floats, Tensors,
                dicts, or any of the wandb.data_types
            step: The global step in processing. This persists any non-committed earlier steps
                but defaults to not committing the specified step.
            commit: Save the metrics dict to the wandb server and increment the step.
                If false, ``wandb.log`` just updates the current metrics dict with the data
                argument and metrics won't be saved until ``wandb.log`` is called with
                ``commit=True``.
        """
        assert isinstance(data, dict), 'Data must be an instance of dict'
        if not self.wandb_is_initialized:
            return
        logger = logger or get_mltk_logger()
        logger.debug(f'Logging to wandb: data keys={", ".join(data)}')
        try:
            wandb.log(
                data=data,
                step=step,
                commit=commit,
            )
        except Exception as e:
            logger.error(f'Failed to log to wandb cloud, err: {e}')
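
    # Example (hypothetical sketch): log custom scalar metrics to the current
    # run's history. The metric names and values are made up for illustration:
    #
    #   my_model.wandb_log({'val_accuracy': 0.92, 'val_loss': 0.31}, step=10)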
    def _register_attributes(self):
        """This is called when the MltkModel properties are first registered,
        see _check_attributes_registered_decorator()
        """
        try:
            import wandb
        except ImportError:
            get_mltk_logger().warning('Failed to import the wandb Python package, try running: pip install wandb plotly')
            return

        get_mltk_logger().debug('Registering WeightsAndBiasesMixin')

        # Register the various properties for this mixin
        self._attributes.register('wandb.init_kwargs', dtype=dict)
        self._attributes.register('wandb.callback', dtype=WandbCallback)
        self._attributes.register('wandb.model_checkpoint_callback', dtype=WandbModelCheckpoint)
        self._attributes.register('wandb.config', dtype=dict)
        self._attributes.register('wandb.session_id', dtype=str)
        self._attributes.register('wandb.is_initialized', dtype=bool)
        self._attributes.register('wandb.is_disabled', dtype=bool)

        # Register the various model event handlers
        self.add_event_handler(
            MltkModelEvent.AFTER_MODEL_LOAD,
            self._wandb_load
        )
        self.add_event_handler(
            MltkModelEvent.TRAIN_STARTUP,
            self._wandb_init,
            job_type='train'
        )
        self.add_event_handler(
            MltkModelEvent.SUMMARIZE_MODEL,
            self._wandb_log_model_summary,
        )
        self.add_event_handler(
            MltkModelEvent.SUMMARIZE_DATASET,
            self._wandb_log_dataset_summary,
        )
        self.add_event_handler(
            MltkModelEvent.EVALUATE_STARTUP,
            self._wandb_init,
            job_type='evaluate'
        )
        self.add_event_handler(
            MltkModelEvent.QUANTIZE_STARTUP,
            self._wandb_init,
            job_type='quantize'
        )
        self.add_event_handler(
            MltkModelEvent.POPULATE_TRAIN_CALLBACKS,
            self._wandb_populate_train_callbacks
        )
        self.add_event_handler(
            MltkModelEvent.QUANTIZE_SHUTDOWN,
            self._wandb_save_archive
        )
        self.add_event_handler(
            MltkModelEvent.GENERATE_EVALUATE_PLOT,
            self._wandb_save_plot
        )
        self.add_event_handler(
            MltkModelEvent.EVALUATE_SHUTDOWN,
            self._wandb_log_evaluation_summary
        )
        self.add_event_handler(
            MltkModelEvent.AFTER_PROFILE,
            self._wandb_upload_profiling_results
        )

    def _wandb_load(self, logger: logging.Logger, **kwargs):
        """This is called after the MltkModel is loaded"""
        try:
            wandb_session_id_path = self.get_archive_file('wandb/session_id.txt')
            with open(wandb_session_id_path, 'r') as f:
                session_id = f.read().strip()
            self._attributes['wandb.session_id'] = session_id
            logger.debug(f'wandb session id: {self.wandb_session_id}')
        except Exception:
            pass

    def _wandb_init(self, job_type: str, post_process: bool, logger: logging.Logger, **kwargs):
        """This is called at the beginning of train_model(), evaluate_model(), or quantize_model()"""
        import absl.logging

        # Do not initialize if:
        # - the user didn't add the --post arg to the command
        # - or we're building a temp model
        if not post_process or (job_type == 'quantize' and kwargs.get('build', False)):
            if not post_process:
                logger.debug('No post-processing enabled (e.g. train my_model --post), so not initializing wandb')
                self._attributes['wandb.is_disabled'] = True
            return

        if self.wandb_is_initialized or self.wandb_is_disabled:
            return

        if job_type == 'train':
            resume = 'never'
            self._attributes['wandb.session_id'] = iso_time_filename_str()
            os.makedirs(f'{self.log_dir}/wandb', exist_ok=True)
            with open(f'{self.log_dir}/wandb/session_id.txt', 'w') as f:
                f.write(self.wandb_session_id)
            logger.debug(f'wandb session id: {self.wandb_session_id}')
        else:
            resume = 'must'
            if self.wandb_session_id is None:
                return

        try:
            logger.info('Initializing wandb')
            init_kwargs = self.wandb_init_kwargs
            settings = init_kwargs.pop('settings', dict(show_info=False))
            wandb.init(
                project=self.name if not self.test_mode_enabled else f'{self.name}-test',
                job_type=job_type,
                dir=self.log_dir,
                id=self.wandb_session_id,
                resume=resume,
                settings=settings,
                **init_kwargs
            )
        except Exception as e:
            prepend_exception_msg(e, 'Failed to init wandb')
            raise

        if job_type == 'train':
            absl.logging.set_verbosity('ERROR')
            config = self.wandb_config
            if config is not None:
                config['epochs'] = self.epochs
                config['batch_size'] = self.batch_size
                config['classes'] = self.classes
                config['class_weights'] = self.class_weights
                wandb.config.update(config)

        self._attributes['wandb.is_initialized'] = True
        logger.info('wandb initialized')

    def _wandb_populate_train_callbacks(self, keras_callbacks: list, **kwargs):
        """This is called after train_model() populates the various Keras callbacks,
        but before training starts"""
        if not self.wandb_is_initialized:
            return

        # Use the user-supplied callback if one was given,
        # otherwise fall back to the default WandbCallback
        callback = self.wandb_callback
        if callback is None:
            callback = WandbCallback()
        keras_callbacks.append(callback)

        weights_dir = f'{self.log_dir}/train/weights'
        os.makedirs(weights_dir, exist_ok=True)
        weights_file_format = self.weights_file_format
        weights_path = f'{weights_dir}/{weights_file_format}'
        checkpoint_callback = self.wandb_model_checkpoint_callback
        if checkpoint_callback is not None:
            checkpoint_callback.filepath = weights_path
            keras_callbacks.append(checkpoint_callback)

    def _wandb_log_model_summary(self, summary: str, logger: logging.Logger, **kwargs):
        """Log the model summary to wandb if we're training"""
        if self.loaded_subset == 'training':
            self.wandb_log({'model_summary': _generate_html(summary)}, logger=logger)

    def _wandb_log_dataset_summary(self, summary: str, logger: logging.Logger, **kwargs):
        """Log the dataset summary to wandb"""
        self.wandb_log({f'dataset_summary-{self.loaded_subset}': _generate_html(summary)}, logger=logger)

    def _wandb_log_evaluation_summary(
        self,
        results: EvaluationResults,
        tflite: bool,
        logger: logging.Logger,
        **kwargs
    ):
        """Log the evaluation summary to wandb"""
        summary = results.generate_summary()
        eval_type = 'tflite' if tflite else 'keras'
        self.wandb_log({f'eval_summary-{eval_type}': _generate_html(summary)}, logger=logger)
        self._wandb_save_archive(logger=logger)

    def _wandb_save_archive(self, logger: logging.Logger, **kwargs):
        """Upload the model archive to wandb"""
        if not self.wandb_is_initialized:
            return
        self.add_archive_file(f'{self.log_dir}/wandb/session_id.txt')
        archive_path = self.archive_path
        model_specification_path = self.model_specification_path
        dst_dir = wandb.run.settings.files_dir
        shutil.copy(archive_path, dst_dir)
        shutil.copy(model_specification_path, dst_dir)
        logger.debug(f'Uploading to wandb: {archive_path}')
        self.wandb_save(f'{dst_dir}/{os.path.basename(archive_path)}', logger=logger)
        self.wandb_save(f'{dst_dir}/{os.path.basename(model_specification_path)}', logger=logger)

    def _wandb_save_plot(self, name, fig, tflite: bool, logger: logging.Logger, **kwargs):
        """Log an evaluation plot to wandb"""
        if not self.wandb_is_initialized:
            return
        name += '-tflite' if tflite else '-keras'
        logger.debug(f'Saving wandb plot: {name}')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            self.wandb_log({name: copy.deepcopy(fig)})

    def _wandb_upload_profiling_results(
        self,
        results: ProfilingModelResults,
        logger: logging.Logger,
        **kwargs
    ):
        """Log the profiling results to wandb"""
        self._wandb_init(
            job_type='profile',
            logger=logger,
            post_process=True
        )
        if not self.wandb_is_initialized:
            return

        accelerator = results.accelerator or 'cmsis'
        self.wandb_log({f'profiling_report-{accelerator}': _generate_html(results.to_string())})

        # Convert the per-layer profiling results into a wandb table
        results_dict = results.to_dict()
        layers = results_dict['layers']
        layers_headers = list(layers[0].keys())[1:]
        layers_data = []
        for layer in layers:
            row = []
            for key in layers_headers:
                row.append(layer[key])
            layers_data.append(row)
        layer_table = wandb.Table(data=layers_data, columns=layers_headers, allow_mixed_types=True)
        self.wandb_log({f'profiling_layers_table-{accelerator}': layer_table})


def _generate_html(data: str):
    """Wrap the given text in a <pre> block so it renders verbatim in the wandb web UI"""
    return wandb.Html(
        f'<pre class="background-white">{data}</pre>'
    )
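
# Usage note (sketch, inferred from the post_process handling above): wandb
# logging only activates when the --post option is passed to the corresponding
# MLTK command, e.g.:
#
#   mltk train my_model --post
#   mltk evaluate my_model --post
#   mltk quantize my_model --post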