Source code for mltk.core.model.mixins.image_dataset_mixin


import types
from typing import List, Tuple, Union, Dict, Callable

import numpy as np


from mltk.utils.python import prepend_exception_msg
from mltk.utils.process_pool_manager import ProcessPoolManager
from mltk.core.utils import (convert_y_to_labels, get_mltk_logger)

from .data_generator_dataset_mixin import (DataGeneratorDatasetMixin, DataGeneratorContext)
from ..model_attributes import MltkModelAttributesDecorator
from ..model_event import MltkModelEvent


@MltkModelAttributesDecorator()
class ImageDatasetMixin(DataGeneratorDatasetMixin):
    """Provides image dataset properties to the base :py:class:`~MltkModel`

    The properties below are thin accessors over ``self._attributes``.
    :py:meth:`load_dataset` consumes them to build the training and/or
    validation data generators for the requested subset.
    """

    @property
    def dataset(self) -> Union[types.ModuleType,Callable,str]:
        """Path to the image dataset's python module, a function
        that manually loads the dataset, or the file path to a directory of samples.

        If a Python module is provided, it must implement the function:

        .. highlight:: python
        .. code-block:: python

           def load_data():
              ...

        The load_data() function should either return a tuple as:
        (x_train, y_train), (x_test, y_test)
        OR it should return the path to a directory containing the dataset's samples.

        If a function is provided, the function should return the tuple:
        (x_train, y_train), (x_test, y_test)
        OR it should return the path to a directory containing the dataset's samples.
        """
        return self._attributes.get_value('dataset.dataset', default=None)
    @dataset.setter
    def dataset(self, v: Union[types.ModuleType,Callable,str]):
        self._attributes['dataset.dataset'] = v


    @property
    def follow_links(self) -> bool:
        """Whether to follow symlinks inside class sub-directories

        Default: ``True``
        """
        return self._attributes.get_value('image.follow_links', default=True)
    @follow_links.setter
    def follow_links(self, v: bool):
        self._attributes['image.follow_links'] = v


    @property
    def shuffle_dataset_enabled(self) -> bool:
        """Shuffle the dataset directory once

        Default: ``false``

        - If true, the dataset directory will be shuffled the first time it is processed
          and an index containing the shuffled file names is generated in the training log directory.
          The index is reused to maintain the shuffled order for subsequent processing.
        - If false, then the dataset samples are sorted alphabetically and saved to an index in the dataset directory.
          The alphabetical index file is used for subsequent processing.
        """
        return self._attributes.get_value('image.shuffle_dataset_enabled', default=False)
    @shuffle_dataset_enabled.setter
    def shuffle_dataset_enabled(self, v: bool):
        self._attributes['image.shuffle_dataset_enabled'] = v


    @property
    def image_classes(self) -> List[str]:
        """Return a list of class labels the model should classify"""
        return self._attributes['image.classes']
    @image_classes.setter
    def image_classes(self, v: List[str]):
        self._attributes['image.classes'] = v


    @property
    def image_input_shape(self) -> Tuple[int]:
        """Return the image input shape as a tuple of integers"""
        return self._attributes['image.input_shape']
    @image_input_shape.setter
    def image_input_shape(self, v: Tuple[int]):
        self._attributes['image.input_shape'] = v


    @property
    def target_size(self) -> Tuple[int]:
        """Return the target size of the generated images.
        The image data generator will automatically resize all images to this size.
        If omitted, ``my_model.input_shape`` is used.

        .. note:: This is only used if providing a directory image dataset
        """
        return self._attributes.get_value('image.target_size', default=None)
    @target_size.setter
    def target_size(self, v: Tuple[int]):
        self._attributes['image.target_size'] = v


    @property
    def class_mode(self) -> str:
        """Determines the type of label arrays that are returned.

        Default: `categorical`

        - **categorical** -  2D one-hot encoded labels
        - **binary** - 1D binary labels
        - **sparse** - 1D integer labels
        - **input** - images identical to input images (mainly used to work with autoencoders)
        """
        return self._attributes.get_value('image.class_mode', default='categorical')
    @class_mode.setter
    def class_mode(self, v: str):
        self._attributes['image.class_mode'] = v


    @property
    def color_mode(self) -> str:
        """The type of image data to use

        Default: ``auto``

        May be one of the following:

        - **auto** - Automatically determine the color mode based on the input shape channels
        - **grayscale** - Convert the images to grayscale (if necessary). The input shape must only have 1 channel
        - **rgb** - The input shape must only have 3 channels
        - **rgba** - The input shape must have 4 channels
        """
        return self._attributes.get_value('image.color_mode', default='auto')
    @color_mode.setter
    def color_mode(self, v: str):
        self._attributes['image.color_mode'] = v


    @property
    def interpolation(self) -> str:
        """Interpolation method used to resample the image if the target size is different from that of the loaded image

        Default: ``bilinear``

        Supported methods are ``none``, ``nearest``, ``bilinear``, ``bicubic``, ``lanczos``, ``box`` and ``hamming`` .
        If ``none`` is used then the generated images are **not automatically resized**.
        In this case, the :py:class:`mltk.core.preprocess.image.parallel_generator.ParallelImageDataGenerator` ``preprocessing_function``
        argument should be used to reshape the image to the expected model input shape.
        """
        return self._attributes.get_value('image.interpolation', default='bilinear')
    @interpolation.setter
    def interpolation(self, v: str):
        self._attributes['image.interpolation'] = v


    @property
    def datagen(self):
        """Training data generator.

        Should be a reference to a :py:class:`mltk.core.preprocess.image.parallel_generator.ParallelImageDataGenerator` instance
        OR `tensorflow.keras.preprocessing.image.ImageDataGenerator <https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator>`_
        """
        return self._attributes.get_value('image.datagen', default=None)
    @datagen.setter
    def datagen(self, v):
        self._attributes['image.datagen'] = v


    @property
    def validation_datagen(self):
        """Validation/evaluation data generator.

        If omitted, then ``datagen`` is used for validation and evaluation.

        Should be a reference to a :py:class:`mltk.core.preprocess.image.parallel_generator.ParallelImageDataGenerator` instance
        OR `tensorflow.keras.preprocessing.image.ImageDataGenerator <https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator>`_
        """
        return self._attributes.get_value('image.validation_datagen', default=None)
    @validation_datagen.setter
    def validation_datagen(self, v):
        self._attributes['image.validation_datagen'] = v


    def load_dataset(
        self,
        subset: str,
        classes: List[str]=None,
        max_samples_per_class: int=-1,
        test:bool = False,
        **kwargs
    ): # pylint: disable=arguments-differ
        """Pre-process the dataset and prepare the model dataset attributes

        On return ``self.x``, ``self.validation_data``, ``self.class_counts``
        and ``self.datagen_context`` describe the generators for ``subset``.

        Args:
            subset: Subset to load; the code distinguishes ``training``,
                ``validation`` and ``evaluation``
            classes: Class labels to use. When empty, ``mltk_model.classes`` is used
            max_samples_per_class: Maximum number of samples to use per class,
                ``-1`` means no limit
            test: If true, force ``batch_size`` and ``max_samples_per_class`` to 3
                and limit the training generator to one pending batch
            **kwargs: Extra arguments forwarded unchanged to the
                ``BEFORE_LOAD_DATASET`` and ``AFTER_LOAD_DATASET`` event handlers

        Raises:
            Exception: If a required model property (dataset, classes, input_shape,
                datagen) is missing or the loaded dataset has an unexpected form
        """
        self.loaded_subset = subset

        logger = get_mltk_logger()
        ProcessPoolManager.set_logger(logger)

        # First download the dataset if necessary
        if self.dataset is None:
            raise Exception('Must specify dataset, e.g.: mltk_model.dataset = tf.keras.datasets.cifar10')

        self.trigger_event(
            MltkModelEvent.BEFORE_LOAD_DATASET,
            subset=subset,
            test=test,
            **kwargs
        )

        # NOTE: This resolves to the module-level load_dataset() helper, not this method
        dataset_data = load_dataset(self.dataset)

        if not classes:
            if not self.classes or not isinstance(self.classes, (list,tuple)):
                raise Exception('Must specify mltk_model.classes which must be a list of class labels')
            classes = self.classes

        if self.input_shape is None or len(self.input_shape) != 3:
            raise Exception('Must specify mltk_model.input_shape which must be a tuple (height, width, depth)')

        if self.datagen is None:
            raise Exception('Must specify mltk_model.datagen')

        if not hasattr(self, 'batch_size'):
            logger.warning('MltkModel does not define batch_size, defaulting to 32')
            batch_size = 32
        else:
            batch_size = self.batch_size

        input_depth = self.input_shape[2]
        color_mode = self.color_mode
        if color_mode == 'auto':
            if input_depth == 1:
                color_mode = 'grayscale'
            elif input_depth == 3:
                color_mode = 'rgb'
            else:
                raise Exception('mltk_model.input_shape[2] must be 1 or 3 (i.e. grayscale or rgb)')

        if input_depth == 1 and color_mode != 'grayscale':
            logger.warning('mltk_model.input_shape[2]=1 but mltk_model.color_mode != grayscale')
        if input_depth == 3 and color_mode != 'rgb':
            logger.warning('mltk_model.input_shape[2]=3 but mltk_model.color_mode != rgb')

        target_size = self.target_size or self.input_shape[:2]
        logger.debug(f'Target image size={target_size}')

        eval_shuffle = False
        eval_augmentation_enabled = False
        if test:
            batch_size = 3
            max_samples_per_class = batch_size
            if hasattr(self, 'batch_size'):
                self.batch_size = batch_size
            self.datagen.max_batches_pending = 1
            logger.debug(f'Test mode enabled, forcing max_samples_per_class={max_samples_per_class}, batch_size={batch_size}')

        if self.loaded_subset == 'evaluation':
            if hasattr(self, 'eval_shuffle'):
                eval_shuffle = self.eval_shuffle
            if hasattr(self, 'eval_augment'):
                eval_augmentation_enabled = self.eval_augment
            if max_samples_per_class == -1 and hasattr(self, 'eval_max_samples_per_class'):
                max_samples_per_class = self.eval_max_samples_per_class

        train_datagen = None
        validation_datagen = None
        if self.loaded_subset == 'training':
            training_datagen_creator = self.get_datagen_creator('training')

        # Get the validation data generator if one was specified
        # otherwise fallback to the training data generator
        validation_datagen_creator = self.get_datagen_creator('validation')

        # If a custom loading function was specified
        if isinstance(dataset_data, (tuple,list)):
            if not(len(dataset_data) == 2 or len(dataset_data) == 4):
                raise Exception('mltk_model.dataset should return a tuple of the form: (x_train, y_train), (x_test, y_test)')

            if len(dataset_data) == 2:
                # Use dedicated local names here: re-using "test" would clobber the
                # boolean "test" argument that is forwarded to AFTER_LOAD_DATASET below
                train_split, test_split = dataset_data
                if not isinstance(train_split, (list, tuple)) or len(train_split) != 2:
                    raise Exception('mltk_model.dataset should return a tuple of the form: (x_train, y_train), (x_test, y_test)')
                if not isinstance(test_split, (list, tuple)) or len(test_split) != 2:
                    raise Exception('mltk_model.dataset should return a tuple of the form: (x_train, y_train), (x_test, y_test)')

                x_train, y_train = train_split
                x_test, y_test = test_split
            else:
                x_train, y_train, x_test, y_test = dataset_data

            if self.class_mode == 'categorical' and y_train.shape[-1] != len(classes):
                raise Exception(f'y_train.shape[-1] ({y_train.shape[-1]}) != len(mltk_model.classes) ({len(classes)}). ' \
                    'Perhaps you need to convert your dataset to categorical?')
            if self.class_mode == 'categorical' and y_test.shape[-1] != len(classes):
                # Report the shape of y_test (the array that actually failed the check)
                raise Exception(f'y_test.shape[-1] ({y_test.shape[-1]}) != len(mltk_model.classes) ({len(classes)}). ' \
                    'Perhaps you need to convert your dataset to categorical?')

            if self.loaded_subset == 'training':
                if max_samples_per_class != -1 and self.class_mode == 'categorical':
                    x_train, y_train = _clamp_max_samples_per_class(x_train, y_train, max_samples_per_class)

                train_datagen = training_datagen_creator.flow(
                    x_train, y_train,
                    batch_size=batch_size,
                    shuffle=True
                )
                self.class_counts['training'] = _get_class_counts(train_datagen.y, classes=classes, class_mode=self.class_mode)

            if max_samples_per_class != -1 and self.class_mode == 'categorical':
                x_test, y_test = _clamp_max_samples_per_class(x_test, y_test, max_samples_per_class)

            validation_datagen = validation_datagen_creator.flow(
                x_test, y_test,
                batch_size=batch_size,
                shuffle=eval_shuffle if self.loaded_subset == 'evaluation' else True
            )
            self.class_counts['validation'] = _get_class_counts(validation_datagen.y, classes=classes, class_mode=self.class_mode)

        # If a directory was specified
        elif isinstance(dataset_data, str):
            from mltk.core.preprocess.image.parallel_generator import ParallelImageDataGenerator

            shuffle_index_dir = None
            if self.shuffle_dataset_enabled:
                shuffle_index_dir = self.get_shuffle_index_dir()
                logger.debug(f'shuffle_index_dir={shuffle_index_dir}')

            logger.debug(f'Dataset directory: {dataset_data}')
            batch_shape = (batch_size,) + tuple(self.input_shape)
            logger.debug(f'Batch shape: {batch_shape}')

            # Kept separate from the caller's **kwargs so that the AFTER_LOAD_DATASET
            # event receives the same extra arguments as BEFORE_LOAD_DATASET
            flow_kwargs = dict(
                directory=dataset_data,
                target_size=target_size,
                batch_shape=batch_shape,
                classes=classes,
                class_mode=self.class_mode,
                color_mode=color_mode,
                interpolation=self.interpolation,
                follow_links=self.follow_links,
                shuffle_index_dir=shuffle_index_dir,
                list_valid_filenames_in_directory_function=_get_list_valid_filenames_function(self.dataset),
            )

            if self.loaded_subset == 'training':
                training_datagen_creator.max_samples_per_class = max_samples_per_class
                # "class_counts" is only passed to the MLTK parallel generator
                if isinstance(training_datagen_creator, ParallelImageDataGenerator):
                    flow_kwargs['class_counts'] = self.class_counts['training']
                train_datagen = training_datagen_creator.flow_from_directory(
                    subset='training',
                    shuffle=True,
                    **flow_kwargs
                )
            flow_kwargs.pop('class_counts', None)

            if self.loaded_subset in ('training', 'validation'):
                validation_datagen_creator.max_samples_per_class = max_samples_per_class
                if isinstance(validation_datagen_creator, ParallelImageDataGenerator):
                    flow_kwargs['class_counts'] = self.class_counts['validation']
                validation_datagen = validation_datagen_creator.flow_from_directory(
                    subset='validation',
                    shuffle=True,
                    **flow_kwargs
                )
            flow_kwargs.pop('class_counts', None)

            if self.loaded_subset == 'evaluation':
                validation_datagen_creator.max_samples_per_class = max_samples_per_class
                validation_datagen_creator.validation_augmentation_enabled = eval_augmentation_enabled
                if isinstance(validation_datagen_creator, ParallelImageDataGenerator):
                    flow_kwargs['class_counts'] = self.class_counts['validation']
                validation_datagen = validation_datagen_creator.flow_from_directory(
                    subset='validation',
                    shuffle=eval_shuffle,
                    **flow_kwargs
                )
            flow_kwargs.pop('class_counts', None)

        else:
            raise Exception(
                'mltk_model.dataset must return return a tuple as: (x_train, y_train), (x_test, y_test)'
                ' or a file path to a directory of samples'
            )

        # Fix issue with:
        # tensorflow.keras.preprocessing.image.ImageDataGenerator
        _patch_image_iterator(validation_datagen)
        if self.class_counts['validation']:
            validation_datagen.max_samples = sum(self.class_counts['validation'].values())

        self.x = None
        self.validation_data = None

        if self.loaded_subset == 'training':
            self.x = train_datagen

        if self.loaded_subset in ('training', 'validation'):
            self.validation_data = validation_datagen

        if self.loaded_subset == 'evaluation':
            self.x = train_datagen if validation_datagen is None else validation_datagen

        self.datagen_context = DataGeneratorContext(
            subset = self.loaded_subset,
            train_datagen = train_datagen,
            train_class_counts = self.class_counts['training'],
            validation_datagen = validation_datagen,
            validation_class_counts = self.class_counts['validation']
        )

        self.trigger_event(
            MltkModelEvent.AFTER_LOAD_DATASET,
            subset=subset,
            test=test,
            **kwargs
        )


    def _register_attributes(self):
        """Register the ``image.*`` attributes and their accepted types

        Returns:
            A callback which registers the model-parameter populate function.
            It is returned rather than invoked because attributes cannot be
            accessed while they are still being registered.
        """
        from mltk.core.preprocess.image.parallel_generator import ParallelImageDataGenerator
        from tensorflow.keras.preprocessing.image import ImageDataGenerator

        self._attributes.register('image.follow_links', dtype=bool)
        self._attributes.register('image.shuffle_dataset_enabled', dtype=bool)
        self._attributes.register('image.input_shape', dtype=(list,tuple))
        self._attributes.register('image.target_size', dtype=(list,tuple))
        self._attributes.register('image.classes', dtype=(list,tuple))
        self._attributes.register('image.class_mode', dtype=str)
        self._attributes.register('image.color_mode', dtype=str)
        self._attributes.register('image.interpolation', dtype=str)
        self._attributes.register('image.datagen', dtype=(ParallelImageDataGenerator, ImageDataGenerator))
        self._attributes.register('image.validation_datagen', dtype=(ParallelImageDataGenerator, ImageDataGenerator))

        # We cannot call attributes while we're registering them
        # So we return a function that will be called after
        # all the attributes are registered
        def register_parameters_populator():
            self.add_model_parameter_populate_callback(self._populate_image_dataset_model_parameters)

        return register_parameters_populator


    def _populate_image_dataset_model_parameters(self):
        """Populate the image processing parameters required at inference time

        These parameters will be added to the compiled .tflite TfliteModelParameters metadata.
        At inference time, these parameters are retrieved from the generated .tflite
        and used to process the input images.

        NOTE: This is invoked during the compile_model() API execution.
        """
        if self.datagen is not None:
            self.set_model_parameter('samplewise_norm.rescale', float(self.datagen.rescale or 0.))
            self.set_model_parameter('samplewise_norm.mean_and_std', self.datagen.samplewise_center and self.datagen.samplewise_std_normalization)
def load_dataset(dataset) -> Union[str,tuple]:
    """Resolve ``mltk_model.dataset`` into its dataset data

    Resolution order:

    1. A string is returned as-is
    2. A class exposing ``load_data`` has ``load_data()`` invoked on the class itself
       (a class is callable, so it must be checked before the generic callable case,
       otherwise it would be instantiated instead of loaded)
    3. Any other callable is invoked with no arguments
    4. Any other object (e.g. a module) must expose ``load_data()``, which is invoked

    Args:
        dataset: Directory path, loader function, or module/class/object with ``load_data()``

    Returns:
        Whatever the string/loader provides

    Raises:
        Exception: If ``dataset`` is neither a string, a callable, nor exposes ``load_data()``.
            Exceptions raised by the loader are re-raised after their message is prefixed.
    """
    if isinstance(dataset, str):
        return dataset

    has_load_data = hasattr(dataset, 'load_data')
    is_dataset_class = isinstance(dataset, type) and has_load_data

    if callable(dataset) and not is_dataset_class:
        try:
            return dataset()
        except Exception as e:
            prepend_exception_msg(e, f'Exception while invoking mltk_model.dataset function: {dataset}')
            raise

    # Every Python value is an instance of ``object``,
    # so the only meaningful requirement left is the presence of load_data()
    if not has_load_data:
        raise Exception('If a module or class is set in mltk_model.dataset, the the module/class must specify the function: load_data()')

    try:
        return dataset.load_data()
    except Exception as e:
        prepend_exception_msg(e, f'Exception while invoking mltk_model.dataset.load_data(): {dataset}')
        raise


def _get_list_valid_filenames_function(dataset):
    """Return ``dataset.list_valid_filenames_in_directory`` if it exists, else ``None``"""
    return getattr(dataset, 'list_valid_filenames_in_directory', None)


def _patch_image_iterator(datagen):
    """Patch the KerasImageIterator so that tensorflow.keras.preprocessing.image.ImageDataGenerator
    properly iterates while predicting

    Objects that are not an ``mltk.core.keras.ImageIterator`` (including ``None``) are left untouched.
    Otherwise ``max_samples``/``sample_count`` attributes are added to the iterator and its
    ``next()`` is replaced with a version that raises ``StopIteration`` once more than
    ``max_samples`` samples were returned (``max_samples <= 0`` disables the limit).
    """
    from mltk.core.keras import ImageIterator

    if not isinstance(datagen, ImageIterator):
        return

    datagen.max_samples = -1
    datagen.sample_count = 0

    def _patched_next(self):
        """Return the next batch, or raise StopIteration once the sample limit was exceeded"""
        # NOTE(review): the comparison is ">" (not ">="), so one more batch is returned after
        # sample_count reaches max_samples, and sample_count is never reset - confirm intended
        if datagen.max_samples > 0 and datagen.sample_count > datagen.max_samples:
            raise StopIteration()

        with self.lock:
            index_array = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        batch_samples = self._get_batches_of_transformed_samples(index_array) # pylint: disable=protected-access
        datagen.sample_count += len(batch_samples[0])
        return batch_samples

    datagen.next = types.MethodType(_patched_next, datagen)


def _clamp_max_samples_per_class(x, y, max_samples_per_class):
    """Clamp the number of samples per class to a maximum if necessary

    The first ``max_samples_per_class`` samples of every class are kept and
    the original sample order is preserved.

    Args:
        x: Samples, indexed along the first axis
        y: Labels aligned with ``x`` (converted to class ids with ``convert_y_to_labels``)
        max_samples_per_class: Maximum samples to keep per class.
            A negative value means "no limit" and returns the inputs unchanged.

    Returns:
        Tuple ``(x_truncated, y_truncated)``
    """
    if max_samples_per_class < 0:
        return x, y

    labels = convert_y_to_labels(y)

    # Count per class id with a dict so that non-contiguous class ids
    # (e.g. only classes 0 and 2 are present) do not index out of range
    kept_per_class = {}
    keep_indices = []
    for i, class_id in enumerate(labels):
        class_id = int(class_id)
        n_kept = kept_per_class.get(class_id, 0)
        if n_kept >= max_samples_per_class:
            continue
        kept_per_class[class_id] = n_kept + 1
        keep_indices.append(i)

    # Explicit integer dtype so that an empty selection is still a valid index array
    keep_indices = np.asarray(keep_indices, dtype=np.intp)

    # Index selection works for labels of any rank (1D sparse or 2D one-hot)
    return np.asarray(x)[keep_indices], np.asarray(y)[keep_indices]


def _get_class_counts(y, classes:List[str], class_mode:str) -> Dict[str,int]:
    """Count the number of samples of each class

    Args:
        y: Sample labels; one-hot encoded when ``class_mode`` is ``categorical``,
            otherwise integer class ids
        classes: Class names, the list index is the class id
        class_mode: Label mode. ``input`` has no class labels and yields an empty dict

    Returns:
        Dictionary mapping every class name to its sample count (0 if absent)

    Raises:
        IndexError: If ``y`` contains a class id outside the range of ``classes``
    """
    class_counts = {}
    if class_mode == 'categorical':
        y = convert_y_to_labels(y)

    if class_mode != 'input':
        # Flatten so that (n, 1) shaped label arrays are also accepted by bincount()
        counts = np.bincount(np.asarray(y).ravel(), minlength=len(classes))
        if len(counts) > len(classes):
            raise IndexError(
                f'Dataset labels contain class id {len(counts) - 1} '
                f'but only {len(classes)} classes were specified'
            )
        for class_name, count in zip(classes, counts):
            # Plain python ints keep the result consistent with the annotated return type
            class_counts[class_name] = int(count)

    return class_counts