from typing import Union, Tuple, List, Callable
import types
import numpy as np
from mltk.utils.python import prepend_exception_msg
from mltk.utils.process_pool_manager import ProcessPoolManager
from mltk.core.utils import get_mltk_logger
from .data_generator_dataset_mixin import (DataGeneratorDatasetMixin, DataGeneratorContext)
from ..model_attributes import MltkModelAttributesDecorator
from ..model_event import MltkModelEvent
@MltkModelAttributesDecorator()
class AudioDatasetMixin(DataGeneratorDatasetMixin):
    """Provides audio dataset properties to the base :py:class:`~MltkModel`

    .. seealso::

       - `AudioFeatureGenerator documentation <https://siliconlabs.github.io/mltk/docs/audio/audio_feature_generator.html>`_
       - `AudioFeatureGenerator API docs <https://siliconlabs.github.io/mltk/docs/python_api/data_preprocessing/audio_feature_generator.html>`_
       - `ParallelAudioDataGenerator API docs <https://siliconlabs.github.io/mltk/docs/python_api/data_preprocessing/audio_data_generator.html>`_
    """
    @property
    def dataset(self) -> Union[str,types.ModuleType,Callable]:
        """Path to the audio dataset's python module, a function
        that manually loads the dataset, or the file path to a directory of samples.

        If a Python module is provided, it must implement the function:

        .. highlight:: python
        .. code-block:: python

           def load_data():
               ...

        which should return the file path to the dataset's directory.

        If a function is provided, the function should return
        the path to a directory containing the dataset's samples.
        """
        return self._attributes['dataset.dataset']

    @dataset.setter
    def dataset(self, v: Union[str,types.ModuleType,Callable]):
        self._attributes['dataset.dataset'] = v
    @property
    def follow_links(self) -> bool:
        """Whether to follow symlinks inside class subdirectories

        Default: ``True``
        """
        return self._attributes.get_value('audio.follow_links', default=True)

    @follow_links.setter
    def follow_links(self, v: bool):
        self._attributes['audio.follow_links'] = v
    @property
    def shuffle_dataset_enabled(self) -> bool:
        """Shuffle the dataset directory once

        Default: ``False``

        - If True, the dataset directory will be shuffled the first time it is processed and
          an index containing the shuffled file names is generated in the training log directory.
          The index is reused to maintain the shuffled order for subsequent processing.
        - If False, then the dataset samples are sorted alphabetically and saved to an index in the dataset directory.
          The alphabetical index file is used for subsequent processing.
        """
        return self._attributes.get_value('audio.shuffle_dataset_enabled', default=False)

    @shuffle_dataset_enabled.setter
    def shuffle_dataset_enabled(self, v: bool):
        self._attributes['audio.shuffle_dataset_enabled'] = v
    @property
    def class_mode(self) -> str:
        """Determines the type of label arrays that are returned.

        Default: ``categorical``

        - **categorical** - 2D one-hot encoded labels
        - **binary** - 1D binary labels
        - **sparse** - 1D integer labels
        - **input** - images identical to input images (mainly used to work with autoencoders)
        """
        return self._attributes.get_value('audio.class_mode', default='categorical')

    @class_mode.setter
    def class_mode(self, v: str):
        self._attributes['audio.class_mode'] = v
    @property
    def audio_classes(self) -> List[str]:
        """List of class labels the model should classify"""
        return self._attributes['audio.classes']

    @audio_classes.setter
    def audio_classes(self, v: List[str]):
        self._attributes['audio.classes'] = v
@property
def audio_input_shape(self) -> Tuple[int, int, int]:
"""Get the shape of the spectrogram generated by the :py:class:`mltk.core.preprocess.audio.audio_feature_generator.AudioFeatureGenerator` as (height, width, 1)
.. note::
If frontend_enabled = True then the input size is automatically calculated based the on the
:py:class:`mltk.core.preprocess.audio.audio_feature_generator.AudioFeatureGeneratorSettings`
If frontend_enabled = False then the input size must be manually specified.
"""
if self.datagen is not None and self.datagen.frontend_enabled:
if self._attributes.value_is_set('audio.manual_in_shape'):
get_mltk_logger().warning('When ParallelAudioDataGenerator.frontend_enabled=True, the model.input shape is automatically calculated, however, the input_shape has been manually set')
spectrogram_shape = self.datagen.frontend_settings.spectrogram_shape
if self.datagen.add_channel_dimension:
spectrogram_shape = spectrogram_shape + (1,)
return spectrogram_shape
else:
return self._attributes['audio.manual_in_shape']
@audio_input_shape.setter
def audio_input_shape(self, v):
if self.datagen is not None and self.datagen.frontend_enabled:
raise Exception('mltk_model.input_shape is determined dynamically based on the AudioFeatureGeneratorSettings when datagen.frontend_enabled=True. In this case, it cannot be manually set')
self._attributes['audio.manual_in_shape'] = v
@property
def sample_length_ms(self) -> int:
"""Get the data generator sample length in milliseconds"""
if self.datagen is None:
raise Exception('You must specify mltk_model.datagen')
return self.datagen.sample_length_ms
@sample_length_ms.setter
def sample_length_ms(self, v: int):
if self.datagen is None:
raise Exception('You must specify mltk_model.datagen')
self.datagen.sample_length_ms = v
@property
def sample_rate_hz(self) -> int:
"""Get the data generator sample rate in hertz"""
if self.datagen is None:
raise Exception('You must specify mltk_model.datagen')
return self.datagen.sample_rate_hz
@sample_rate_hz.setter
def sample_rate_hz(self, v: int):
if self.datagen is None:
raise Exception('You must specify mltk_model.datagen')
self.datagen.sample_rate_hz = v
@property
def frontend_settings(self):
"""Get the data generator's :py:class:`mltk.core.preprocess.audio.audio_feature_generator.AudioFeatureGeneratorSettingsSettings` """
if self.datagen is None:
raise Exception('You must specify mltk_model.datagen')
return self.datagen.frontend_settings
    @property
    def datagen(self):
        """Training data generator.

        Should be a reference to a :py:class:`mltk.core.preprocess.audio.parallel_generator.ParallelAudioDataGenerator` instance
        """
        return self._attributes.get_value('audio.datagen', default=None)

    @datagen.setter
    def datagen(self, v):
        self._attributes['audio.datagen'] = v
    @property
    def validation_datagen(self):
        """Validation/evaluation data generator.

        If omitted, then :py:attr:`~datagen` is used for validation and evaluation.

        Should be a reference to a :py:class:`mltk.core.preprocess.audio.parallel_generator.ParallelAudioDataGenerator` instance
        """
        return self._attributes.get_value('audio.validation_datagen', default=None)

    @validation_datagen.setter
    def validation_datagen(self, v):
        self._attributes['audio.validation_datagen'] = v
[docs] def load_dataset(
self,
subset: str,
classes: List[str]=None,
max_samples_per_class: int=-1,
test:bool = False,
**kwargs,
): # pylint: disable=arguments-differ
"""Pre-process the dataset and prepare the model dataset attributes
Args:
subset: Data subset name
"""
self.loaded_subset = subset
logger = get_mltk_logger()
ProcessPoolManager.set_logger(logger)
if self.datagen is None:
raise Exception('Must specify mltk_model.datgen')
if not classes:
if not self.classes or not isinstance(self.classes, (list,tuple)):
raise Exception('Must specify mltk_model.classes which must be a list of class labels')
classes = self.classes
# First download the dataset if necessary
if self.dataset is None:
raise Exception('Must specify mltk_model.dataset')
self.trigger_event(
MltkModelEvent.BEFORE_LOAD_DATASET,
subset=subset,
test=test,
**kwargs
)
dataset_dir = _load_dataset(self.dataset)
if not isinstance(dataset_dir, str):
raise Exception('Dataset must be a path to a directory')
if not hasattr(self, 'batch_size'):
logger.warning('MltkModel does not define batch_size, defaulting to 32')
batch_size = 32
else:
batch_size = self.batch_size
shuffle_index_dir = None
if self.shuffle_dataset_enabled:
shuffle_index_dir = self.get_shuffle_index_dir()
logger.debug(f'shuffle_index_dir={shuffle_index_dir}')
eval_shuffle = False
eval_augmentation_enabled = False
if test:
batch_size = 3
max_samples_per_class = batch_size
if hasattr(self, 'batch_size'):
self.batch_size = batch_size
self.datagen.max_batches_pending = 1
logger.debug(f'Test mode enabled, forcing max_samples_per_class={max_samples_per_class}, batch_size={batch_size}')
if self.loaded_subset == 'evaluation':
if hasattr(self, 'eval_shuffle'):
eval_shuffle = self.eval_shuffle
if hasattr(self, 'eval_augment'):
eval_augmentation_enabled = self.eval_augment
if max_samples_per_class == -1 and hasattr(self, 'eval_max_samples_per_class'):
max_samples_per_class = self.eval_max_samples_per_class
train_datagen = None
validation_datagen = None
if self.loaded_subset == 'training':
training_datagen_creator = self.get_datagen_creator('training')
if training_datagen_creator is None:
raise Exception('Must specify mltk_model.datagen for model')
# Get the validation data generator if one was specified
# otherwise fallback to the training data generator
validation_datagen_creator = self.get_datagen_creator('validation')
logger.debug(f'Dataset directory: {dataset_dir}')
kwargs = dict(
directory=dataset_dir,
target_size=self.input_shape[:2], # Get the height and width
classes=classes,
class_mode=self.class_mode,
follow_links=self.follow_links,
batch_size=batch_size,
max_samples_per_class=max_samples_per_class,
shuffle_index_dir=shuffle_index_dir,
list_valid_filenames_in_directory_function=_get_list_valid_filenames_function(self.dataset),
)
if self.loaded_subset == 'training':
train_datagen = training_datagen_creator.flow_from_directory(
subset='training',
shuffle=True,
class_counts=self.class_counts['training'],
**kwargs
)
if self.loaded_subset in ('training', 'validation'):
validation_datagen = validation_datagen_creator.flow_from_directory(
subset='validation',
shuffle=True,
class_counts=self.class_counts['validation'],
**kwargs
)
if self.loaded_subset == 'evaluation':
validation_datagen_creator.validation_augmentation_enabled = eval_augmentation_enabled
validation_datagen = validation_datagen_creator.flow_from_directory(
subset='validation',
shuffle=eval_shuffle,
class_counts=self.class_counts['validation'],
**kwargs
)
self.x = None
self.validation_data = None
if self.loaded_subset == 'training':
self.x = train_datagen
if self.loaded_subset in ('training', 'validation'):
self.validation_data = validation_datagen
if self.loaded_subset == 'evaluation':
self.x = train_datagen if validation_datagen is None else validation_datagen
self.datagen_context = DataGeneratorContext(
subset = self.loaded_subset,
train_datagen = train_datagen,
validation_datagen = validation_datagen,
train_class_counts=self.class_counts['training'],
validation_class_counts=self.class_counts['validation']
)
self.trigger_event(
MltkModelEvent.AFTER_LOAD_DATASET,
subset=subset,
test=test,
**kwargs
)
    def _register_attributes(self):
        """Register this mixin's attributes with the model's attribute manager"""
        # Imported locally (presumably to avoid a circular import at module load — TODO confirm)
        from mltk.core.preprocess.audio.parallel_generator import ParallelAudioDataGenerator

        self._attributes.register('audio.follow_links', dtype=bool)
        self._attributes.register('audio.shuffle_dataset_enabled', dtype=bool)
        self._attributes.register('audio.class_mode', dtype=str)
        self._attributes.register('audio.datagen', dtype=ParallelAudioDataGenerator)
        self._attributes.register('audio.validation_datagen', dtype=ParallelAudioDataGenerator)

        def _set_audio_input_shape(v):
            self.audio_input_shape = v

        # If datagen.frontend_enabled = True then this is used
        self._attributes.register('audio.input_shape', lambda: self.audio_input_shape, setter=_set_audio_input_shape)
        # If datagen.frontend_enabled = False then this is used
        self._attributes.register('audio.manual_in_shape', dtype=(list,tuple))

        self._attributes.register('audio.classes', dtype=(list,tuple))

        # We cannot call attributes while we're registering them
        # So we return a function that will be called after
        # all the attributes are registered
        def register_parameters_populator():
            self.add_model_parameter_populate_callback(self._populate_audio_dataset_model_parameters)

        return register_parameters_populator
    def _populate_audio_dataset_model_parameters(self):
        """Populate the AudioPipeline parameters required by the device at runtime

        These parameters will be added to the compiled .tflite ModelParameters metadata.
        The device retrieves these parameters from the generated .tflite at run-time and
        uses them to process the microphone audio in the AudioPipeline.

        NOTE: This is invoked during the compile_model() API execution.
        """
        if self.datagen is not None:
            parameters = self.model_parameters
            # "datagen.rescale or 0" maps a None/0 rescale to 0.0
            parameters['samplewise_norm.rescale'] = float(self.datagen.rescale or 0)
            # True only when BOTH centering and std normalization are enabled
            parameters['samplewise_norm.mean_and_std'] = self.datagen.samplewise_center and self.datagen.samplewise_std_normalization
            # Embed the AudioFeatureGenerator settings as well
            parameters.update(self.frontend_settings)
def _load_dataset(dataset) -> Union[str,tuple]:
if isinstance(dataset,str):
return dataset
if callable(dataset):
try:
return dataset()
except Exception as e:
prepend_exception_msg(e, f'Exception while invoking mltk_model.dataset function: {dataset}')
raise
if isinstance(dataset, (types.ModuleType, object)):
if not hasattr(dataset, 'load_data'):
raise Exception('If a module or class is set in mltk_model.dataset, the the module/class must specify the function: load_data()')
try:
return dataset.load_data()
except Exception as e:
prepend_exception_msg(e, f'Exception while invoking mltk_model.dataset.load_data(): {dataset}')
raise
raise Exception('mltk_model.dataset must either be file path to a dictionary or callback function')
def _get_list_valid_filenames_function(dataset):
if isinstance(dataset, (types.ModuleType, object)):
if hasattr(dataset, 'list_valid_filenames_in_directory'):
return getattr(dataset, 'list_valid_filenames_in_directory')
return None