Source code for mltk.core.preprocess.image.parallel_generator.iterator

"""Utilities for real-time data augmentation on image data.
"""
import os
import sys
import time
import threading
import inspect
import queue
from typing import List, Tuple
import random
import numpy as np

from mltk.core.keras import (
    array_to_img,
    img_to_array,
    load_img
)
from mltk.core import get_mltk_logger
from mltk.core.keras import DataSequence
from mltk.utils.process_pool import ProcessPool, calculate_n_jobs


class ParallelIterator(DataSequence):
    """Base class for image data iterators.

    # Arguments
        n: Integer, total number of samples in the dataset to loop over.
        batch_shape: Integer, shape of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.
    """
    white_list_formats = ('png', 'jpg', 'jpeg', 'bmp', 'ppm', 'tif', 'tiff', 'npy')

    def __init__(self, n, batch_shape, shuffle, seed, process_params):
        super().__init__()

        self.n = n
        self.batch_shape = batch_shape
        self.seed = seed
        self.total_batches_seen = 0
        self.batch_index = 0
        self.shuffle = shuffle
        self.process_params = process_params
        self.check_reuse_batch_zero = True

        n_jobs = calculate_n_jobs(self.cores)
        # Use half the number of jobs for the validation subset
        if process_params.validation_split and process_params.subset == 'validation':
            n_jobs = max(int(n_jobs*.5), 1)

        get_batch_function = self.process_params.get_batch_function or get_batches_of_transformed_samples
        self.pool = ProcessPool(
            name=self.process_params.subset,
            entry_point=get_batch_function,
            n_jobs=n_jobs,
            debug=self.debug,
            disable_gpu_in_subprocesses=self.disable_gpu_in_subprocesses,
            logger=get_mltk_logger()
        )

        self.batch_generation_started = threading.Event()
        self.current_batch_finished = threading.Event()
        self.current_batch_finished.set()

        self.batch_data = BatchData(
            len(self), 
            shuffle,
            pool=self.pool
        )
       
        self.batch_thread = threading.Thread(
            target=self._generate_batch_data_safe, 
            name=f'Batch data generator:{process_params.subset}',
            daemon=True
        )
        self.batch_thread.start()

 
    @property
    def is_shutdown(self) -> bool:
        return not self.pool.is_running

    @property
    def is_running(self) -> bool:
        return not self.is_shutdown

    @property
    def batch_size(self) -> int:
        return self.batch_shape[0]

        
    def reset(self):
        self.batch_generation_started.clear()
        self.current_batch_finished.set()
        self.batch_data.reset()
        self.batch_index = 0
    
    
    def shutdown(self, wait=True):
        if wait:
            self.current_batch_finished.wait(30)
        self.reset()
        self.pool.shutdown()
       


    def __getitem__(self, idx):
        if idx >= len(self):
            raise ValueError(
                f'Asked to retrieve element {idx}, but the Sequence has length {len(self)}'
            )

        # Some TF APIs use a "peek_and_restore" operation on the first batch.
        # To account for this, we check the callback stack for a function that has "peek" in it
        # If so, we notify the batch_data.get() API that it should save this batch as it will be re-used
        save_batch_zero = False
        if idx == 0 and self.check_reuse_batch_zero:
            self.check_reuse_batch_zero = False
            callback_function_name = sys._getframe(1).f_code.co_name
            if 'peek' in callback_function_name:
                save_batch_zero = True

        self.batch_generation_started.set()
        self.current_batch_finished.clear()
        
        retval, is_last = self.batch_data.get(idx, save_batch_zero=save_batch_zero)
        if is_last:
            self.current_batch_finished.set()
        return retval


    def __len__(self):
        return (self.n + self.batch_size -1) // self.batch_size  # round up


    def on_epoch_end(self):
        pass


    def __iter__(self):
        # Needed if we want to do something like:
        # for x, y in data_gen.flow(...):
        self.batch_index = 0
        return self


    def __next__(self, *args, **kwargs):
        return self.next(*args, **kwargs)


    def next(self):
        """For python 2.x.

        # Returns
            The next batch.
        """
        if self.batch_index >= len(self):
            # Clear the started flag, but do NOT reset
            # This way we don't waste any processed batch data
            self.batch_generation_started.clear()
            raise StopIteration()
        
        self.batch_generation_started.set()
        self.current_batch_finished.clear()
        retval, is_last = self.batch_data.get(self.batch_index)
        if is_last:
            self.current_batch_finished.set()
        
        self.batch_index += 1
        
        return retval


    def _generate_batch_data_safe(self):
        try:
            self._generate_batch_data()
        except Exception as e:
            if not self.is_shutdown:
                get_mltk_logger().error(f'Exception during batch data processing, err: {e}', exc_info=e)
                return
            self.shutdown(wait=False)
    
    
    def _generate_batch_data(self):
        while self.is_running:
            # Wait for training to start
            if not self.batch_generation_started.wait(timeout=0.1):
                continue
        
            if self.seed is not None:
                np.random.seed(self.seed + self.total_batches_seen)
            
            # If the number of samples is not a multiple of the batch size
            # then the last batch needs to wrap to the beginning of the sample indices
            wrap_length = (len(self) * self.batch_size) - self.n
            if self.shuffle:
                index_array = np.random.permutation(self.n)
                if wrap_length > 0:
                    index_array = np.concatenate((index_array, np.random.permutation(wrap_length))) 
            else:
                index_array = np.arange(self.n)
                if wrap_length > 0:
                    index_array = np.concatenate((index_array, np.arange(wrap_length))) 


            self.batch_data.start_batch()
    
            while self.batch_data.have_more_indices:
                while self.batch_data.request_count == 0 and self.batch_data.qsize() > self.max_batches_pending:
                    if self.is_shutdown:
                        return
                    if not self.batch_generation_started.is_set():
                        break
                    self.batch_data.wait()
                
                if not self.batch_generation_started.is_set():
                    break
                
                current_batch_index = self.batch_data.next_index()
                self.total_batches_seen += 1
                offset = current_batch_index*self.batch_size
                batch_index_chunk = index_array[offset:offset+self.batch_size]
                
                batch_filenames = []
                batch_classes = []
                
                for batch_index in batch_index_chunk:
                    batch_filenames.append(self.filenames[batch_index])
                    batch_classes.append(self.classes[batch_index])

                self._invoke_processing(current_batch_index, batch_filenames, batch_classes)


    def _invoke_processing(
        self, 
        batch_index:int, 
        batch_filenames:List[str], 
        batch_classes:List[int]
    ):
        try:
            self.pool(
                batch_index, 
                batch_filenames, 
                batch_classes, 
                params=self.process_params,
                pool_callback=self._pool_callback
            )
        except Exception as e:
            if not self.is_running:
                raise


    def _pool_callback(self, results):
        if results is not None:
            self.batch_data.put(results[0], results[1])
    


[docs]class ParallelProcessParams(): """Adds methods related to getting batches from filenames It includes the logic to transform image files to batches. """
[docs] def __init__( self, image_data_generator, target_size, batch_shape, color_mode, data_format, save_to_dir, save_prefix, save_format, subset, interpolation, class_indices, dtype, directory, class_mode, get_batch_function, preprocessing_function, noaug_preprocessing_function ): """Sets attributes to use later for processing files into a batch.""" self.class_indices = class_indices self.dtype = dtype self.directory = directory self.class_mode = class_mode self.image_data_generator = image_data_generator self.target_size = tuple(target_size) self.batch_shape = batch_shape self.get_batch_function = get_batch_function self.preprocessing_function = preprocessing_function self.noaug_preprocessing_function = noaug_preprocessing_function if color_mode not in {'rgb', 'rgba', 'grayscale'}: raise ValueError('Invalid color mode:', color_mode, '; expected "rgb", "rgba", or "grayscale".') self.color_mode = color_mode self.data_format = data_format self.save_to_dir = save_to_dir self.save_prefix = save_prefix self.save_format = save_format self.interpolation = interpolation self.validation_split = self.image_data_generator._validation_split if subset is not None: validation_split = self.image_data_generator._validation_split if subset == 'validation': split = (0, validation_split) elif subset == 'training': split = (validation_split, 1) else: raise ValueError( 'Invalid subset name: %s;' 'expected "training" or "validation"' % (subset,)) else: split = None self.split = split self.subset = subset
@property def n_classes(self) -> int: return len(self.class_indices)
def get_batches_of_transformed_samples( batch_index:int, filenames:List[str], classes:List[int], params:ParallelProcessParams ) -> Tuple[int, Tuple[np.ndarray, np.ndarray]]: """Gets a batch of transformed samples. Arguments: batch_index: Index of this batch filenames: List of filenames for this batch classes: List of class ids mapping to the filenames list params: Generator parameters # Returns A batch of transformed samples: batch_index, (batch_x, batch_y) """ # Ensure the RNG is unique for each batch if params.subset != 'validation' or params.image_data_generator.validation_augmentation_enabled: random.seed(batch_index + int(time.time())) np.random.seed(batch_index + int(time.time())) batch_size = params.batch_shape[0] assert len(filenames) == params.batch_shape[0] if isinstance(filenames[0], (list,tuple)): batch_x = [] for _ in range(len(filenames[0])): batch_x.append(np.empty(params.batch_shape, dtype=params.dtype)) else: batch_x = np.empty(params.batch_shape, dtype=params.dtype) # build batch of image data def _process_image_file(class_id, filename) -> np.ndarray: filepath = f'{params.directory}/{filename}' if filepath.endswith('npy'): x = np.load(filepath) else: # If set to 'None' don't rescale and let the user do it in preprocessing_func() if params.interpolation is None or params.interpolation.lower() == 'none': img = load_img(filepath, color_mode=params.color_mode) else: img = load_img(filepath, color_mode=params.color_mode, target_size=params.target_size, interpolation=params.interpolation) x = img_to_array(img, dtype='uint8') # Pillow images should be closed after `load_img`, # but not PIL images. if hasattr(img, 'close'): img.close() if params.noaug_preprocessing_function is not None: kwargs = _add_optional_callback_arguments( params.noaug_preprocessing_function, batch_index=i, class_id=class_id, filename=filename, batch_class_ids=classes, batch_filenames=filenames ) x = params.noaug_preprocessing_function(params, x, **kwargs) if params.subset != 'validation' or params.image_data_generator.validation_augmentation_enabled: transform_params = params.image_data_generator.get_random_transform(x.shape) x = x.astype(dtype=np.float32) # float required to do transform below x = params.image_data_generator.apply_transform(x, transform_params) else: x = img_to_array(x).astype(dtype=np.float32) if params.preprocessing_function is not None: kwargs = _add_optional_callback_arguments( params.preprocessing_function, batch_index=i, class_id=class_id, filename=filename, batch_class_ids=classes, batch_filenames=filenames ) x = params.preprocessing_function(params, x, **kwargs) x = params.image_data_generator.standardize(x) # optionally save augmented images to disk for debugging purposes if params.save_to_dir: img = array_to_img(x, params.data_format, scale=True) now = int(time.time() * 1000) fname = f'{filename.replace("/", "-")}_{now}.{params.save_format}' if params.save_prefix: fname = f'{params.save_prefix}_{fname}' img.save(os.path.join(params.save_to_dir, fname)) return x for i, filename in enumerate(filenames): if isinstance(filename, (list,tuple)): for j, fn in enumerate(filename): batch_x[j][i] = _process_image_file(i, fn) else: batch_x[i] = _process_image_file(i, filename) # build batch of labels if params.class_mode == 'input': batch_y = batch_x.copy() elif params.class_mode in {'binary', 'sparse'}: batch_y = np.empty((batch_size,), dtype=params.dtype) for i, clazz in enumerate(classes): batch_y[i] = clazz elif params.class_mode == 'categorical': batch_y = np.zeros((batch_size, params.n_classes), dtype=params.dtype) for i, clazz in enumerate(classes): batch_y[i, clazz] = 1. else: return batch_index, batch_x return batch_index, (batch_x, batch_y) class BatchData: def __init__( self, n:int, shuffle:bool, pool:ProcessPool ): self.n = n self.shuffle = shuffle self.pool = pool self.batch_data = queue.Queue() if shuffle else {} self.batch_data_lock = threading.Condition() self.indices_lock = threading.Condition() self.indices = [] self.batch_counts = [] self.saved_batch_zero = None self.requests = [] self.data_event = threading.Event() @property def have_more_indices(self): with self.indices_lock: return (len(self.indices) + len(self.requests)) > 0 @property def request_count(self): with self.indices_lock: return len(self.requests) def start_batch(self): with self.indices_lock: self.batch_counts.append(self.n) self.indices = [i for i in range(self.n)] def next_index(self): with self.indices_lock: if len(self.requests) > 0: idx = self.requests.pop(0) try: self.indices.remove(idx) except: pass return idx else: return self.indices.pop(0) def wait(self): self.data_event.clear() while self.pool.is_running: if self.data_event.wait(timeout=.1): return True return False def reset(self): if self.shuffle: while not self.batch_data.empty(): self.batch_data.get() else: with self.batch_data_lock: self.batch_data.clear() with self.indices_lock: self.requests = [] self.indices = [i for i in range(self.n)] def qsize(self): if self.shuffle: return self.batch_data.qsize() else: with self.batch_data_lock: return len(self.batch_data) def put(self, index, value): if self.shuffle: self.batch_data.put(value) else: with self.batch_data_lock: self.batch_data[index] = value self.batch_data_lock.notify_all() def get(self, index, save_batch_zero=False): decrement_batch_count = True # If we're returning batch zero and we have a saved one, # then just return that batch if index == 0 and self.saved_batch_zero is not None: retval = self.saved_batch_zero self.saved_batch_zero = None elif self.shuffle: while True: if not self.pool.is_running: raise StopIteration('The data generator has been stopped') try: retval = self.batch_data.get(timeout=0.1) break except queue.Empty: continue else: with self.batch_data_lock: if index not in self.batch_data: with self.indices_lock: self.requests.append(index) self.data_event.set() while index not in self.batch_data: if not self.pool.is_running: raise StopIteration('The data generator has been stopped') self.batch_data_lock.wait(timeout=0.1) retval = self.batch_data[index] del self.batch_data[index] # If this is batch0 and we should save it # then saved a reference to it and do NOt decrement the batch count as it will be returned in a later call if index == 0 and save_batch_zero: self.saved_batch_zero = retval decrement_batch_count = False is_last_in_batch = self._decrement_current_batch_count() if decrement_batch_count else False self.data_event.set() return retval, is_last_in_batch def _decrement_current_batch_count(self) -> bool: is_last_in_batch = False with self.indices_lock: current_count = self.batch_counts[0] if current_count == 1: is_last_in_batch = True self.batch_counts.pop(0) else: self.batch_counts[0] = current_count - 1 return is_last_in_batch def _add_optional_callback_arguments( func, batch_index, class_id, filename, batch_class_ids, batch_filenames ) -> dict: retval = {} args = inspect.getfullargspec(func).args if 'batch_index' in args: retval['batch_index'] = batch_index if 'class_id' in args: retval['class_id'] = class_id if 'filename' in args: retval['filename'] = filename if 'batch_class_ids' in args: retval['batch_class_ids'] = batch_class_ids if 'batch_filenames' in args: retval['batch_filenames'] = batch_filenames return retval