"""Auto-encoder evaluation mixin (module ``mltk.core.model.mixins.evaluate_autoencoder_mixin``)."""
from typing import Callable, List, Optional
from .evaluate_classifier_mixin import EvaluateClassifierMixin
from ..model_attributes import MltkModelAttributesDecorator, CallableType
@MltkModelAttributesDecorator()
class EvaluateAutoEncoderMixin(EvaluateClassifierMixin):
    """Provides evaluation properties and methods to the base :py:class:`~MltkModel`

    .. note:: This mixin is specific to "auto-encoder" models

    Refer to the `Model Evaluation <https://siliconlabs.github.io/mltk/docs/guides/model_evaluation.html>`_ guide for more details.
    """

    @property
    def scoring_function(self) -> Optional[Callable]:
        """The auto-encoder scoring function to use during evaluation

        If `None`, then use the `mltk_model.loss` function
        (see :py:meth:`get_scoring_function`).

        Default: `None`
        """
        return self._attributes.get_value('eval_autoencoder.scoring_function', default=None)

    @scoring_function.setter
    def scoring_function(self, v: Callable):
        self._attributes['eval_autoencoder.scoring_function'] = v

    @property
    def eval_classes(self) -> List[str]:
        """List of classes to use for evaluation.

        The first element should be considered the 'normal' class,
        every other class is considered abnormal and compared independently.
        This is used if the `--classes` argument is not supplied to the `eval` command.
        The attribute is registered to accept either a list or a tuple.

        Default: `[normal, abnormal]`
        """
        return self._attributes.get_value('eval_autoencoder.classes', default=['normal', 'abnormal'])

    @eval_classes.setter
    def eval_classes(self, v: List[str]):
        self._attributes['eval_autoencoder.classes'] = v

    def get_scoring_function(self) -> Callable:
        """Return the scoring function used during evaluation

        If :py:attr:`scoring_function` is set, it is returned unchanged.
        Otherwise the scoring function is derived from the model's ``loss``:

        - ``"mse"`` / ``"mean_squared_error"`` / ``MeanSquaredError`` instance -> ``mse_loss_func``
        - ``"mae"`` / ``"mean_absolute_error"`` / ``MeanAbsoluteError`` instance -> ``mae_loss_func``
        - ``"corr"`` / ``"correlation"`` / ``Correlation`` instance -> ``corr_loss_func``

        String loss identifiers are matched case-insensitively, so Keras aliases
        such as ``"MSE"`` and ``"MAE"`` are recognized.

        Returns:
            The callable used to score auto-encoder reconstructions

        Raises:
            RuntimeError: If the correlation loss is selected but ``corr_loss_func``
                is unavailable, or if the loss is not one of the supported defaults
                and no ``scoring_function`` was specified.
        """
        # A user-supplied scoring function takes precedence; return it before
        # importing the Keras losses so that import is only paid when needed.
        scoring_function = self.scoring_function
        if scoring_function is not None:
            return scoring_function

        from mltk.core.keras.losses import (
            Correlation,
            MeanSquaredError,
            MeanAbsoluteError,
            mse_loss_func,
            corr_loss_func,
            mae_loss_func
        )

        loss = self.loss
        # Only compare string identifiers by name (case-insensitively, since Keras
        # accepts e.g. 'MSE'); loss objects are matched by type below.
        loss_name = loss.lower() if isinstance(loss, str) else None

        if loss_name in ('mse', 'mean_squared_error') or isinstance(loss, MeanSquaredError):
            return mse_loss_func

        if loss_name in ('mae', 'mean_absolute_error') or isinstance(loss, MeanAbsoluteError):
            return mae_loss_func

        if loss_name in ('corr', 'correlation') or isinstance(loss, Correlation):
            # corr_loss_func is falsy when its optional backend could not be loaded
            if not corr_loss_func:
                raise RuntimeError('Failed to get correlation loss function, ensure the Tensorflow-Probability package is properly installed')
            return corr_loss_func

        raise RuntimeError(
            'Only model loss functions: "mse", "mae", "corr" are supported by default.\n'
            'You must specify mltk_model.scoring_function for your model'
        )

    def _register_attributes(self):
        """Register the auto-encoder evaluation attributes with the model's attribute store."""
        self._attributes.register('eval_autoencoder.scoring_function', dtype=CallableType)
        self._attributes.register('eval_autoencoder.classes', dtype=(list, tuple))