Source code for mltk.core.preprocess.audio.parallel_generator.iterator

"""Utilities for real-time data augmentation on image data.
"""
import os
import sys
import random
import time
import threading
import inspect
import queue
from typing import List, Tuple
import numpy as np

from mltk.core import get_mltk_logger
from mltk.core.keras import DataSequence
from mltk.core.preprocess.utils import audio as audio_utils
from mltk.utils.process_pool import ProcessPool, calculate_n_jobs



class ParallelIterator(DataSequence):
    """Base class for image data iterators.

    # Arguments
        n: Integer, total number of samples in the dataset to loop over.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.
    """
    white_list_formats = ('wav')

    def __init__(self, n, batch_size, shuffle, seed, process_params):
        super().__init__()

        self.n = n
        self.batch_size = batch_size
        self.seed = seed
        self.total_batches_seen = 0
        self.batch_index = 0
        self.shuffle = shuffle
        self.process_params = process_params
        self.check_reuse_batch_zero = True

        get_batch_function = self.process_params.get_batch_function or get_batches_of_transformed_samples
        n_jobs = calculate_n_jobs(self.cores)
        # Use half the number of jobs for the validation subset
        if process_params.validation_split and process_params.subset == 'validation':
            n_jobs = max(int(n_jobs*.5), 1)

        self.pool = ProcessPool(
            name=self.process_params.subset,
            entry_point=get_batch_function,
            n_jobs=n_jobs,
            debug=self.debug,
            disable_gpu_in_subprocesses=self.disable_gpu_in_subprocesses,
            logger=get_mltk_logger()
        )

        self.batch_generation_started = threading.Event()
        self.current_batch_finished = threading.Event()
        self.current_batch_finished.set()

        self.batch_data = BatchData(
            len(self), 
            shuffle,
            pool=self.pool
        )
       
        self.batch_thread = threading.Thread(
            target=self._generate_batch_data_safe, 
            name=f'Batch data generator:{process_params.subset}',
            daemon=True
        )
        self.batch_thread.start()

 
    @property
    def is_shutdown(self) -> bool:
        return not self.pool.is_running

    @property
    def is_running(self) -> bool:
        return not self.is_shutdown

        
    def reset(self):
        self.batch_generation_started.clear()
        self.current_batch_finished.set()
        self.batch_data.reset()
        self.batch_index = 0
    
    
    def shutdown(self, wait=True):
        if wait:
            self.current_batch_finished.wait(30)
        self.reset()
        self.pool.shutdown()
       


    def __getitem__(self, idx):
        if idx >= len(self):
            raise ValueError(
                f'Asked to retrieve element {idx}, but the Sequence has length {len(self)}'
            )

        # Some TF APIs use a "peek_and_restore" operation on the first batch.
        # To account for this, we check the callback stack for a function that has "peek" in it
        # If so, we notify the batch_data.get() API that it should save this batch as it will be re-used
        save_batch_zero = False
        if idx == 0 and self.check_reuse_batch_zero:
            self.check_reuse_batch_zero = False
            callback_function_name = sys._getframe(1).f_code.co_name
            if 'peek' in callback_function_name:
                save_batch_zero = True

        self.batch_generation_started.set()
        self.current_batch_finished.clear()
        
        retval, is_last = self.batch_data.get(idx, save_batch_zero=save_batch_zero)
        if is_last:
            self.current_batch_finished.set()
        return retval


    def __len__(self):
        return (self.n + self.batch_size -1) // self.batch_size  # round up


    def on_epoch_end(self):
        pass


    def __iter__(self):
        # Needed if we want to do something like:
        # for x, y in data_gen.flow(...):
        self.batch_index = 0
        return self


    def __next__(self, *args, **kwargs):
        return self.next(*args, **kwargs)


    def next(self):
        """For python 2.x.

        # Returns
            The next batch.
        """
        
        if self.batch_index >= len(self):
            # Clear the started flag, but do NOT reset
            # This way we don't waste any processed batch data
            self.batch_generation_started.clear()
            raise StopIteration()
        
        self.batch_generation_started.set()
        self.current_batch_finished.clear()
        retval, is_last = self.batch_data.get(self.batch_index)
        if is_last:
            self.current_batch_finished.set()
        
        self.batch_index += 1
        
        return retval


    def _generate_batch_data_safe(self):
        try:
            self._generate_batch_data()
        except Exception as e:
            if not self.is_shutdown:
                get_mltk_logger().error(f'Exception during batch data processing, err: {e}', exc_info=e)
                return
            self.shutdown(wait=False)
    
    
    def _generate_batch_data(self):
        while self.is_running:
            # Wait for training to start
            if not self.batch_generation_started.wait(timeout=0.1):
                continue
        
            if self.seed is not None:
                np.random.seed(self.seed + self.total_batches_seen)
            
            # If the number of samples is not a multiple of the batch size
            # then the last batch needs to wrap to the beginning of the sample indices
            wrap_length = (len(self) * self.batch_size) - self.n
            if self.shuffle:
                index_array = np.random.permutation(self.n)
                if wrap_length > 0:
                    index_array = np.concatenate((index_array, np.random.permutation(wrap_length))) 
            else:
                index_array = np.arange(self.n)
                if wrap_length > 0:
                    index_array = np.concatenate((index_array, np.arange(wrap_length))) 
            

            self.batch_data.start_batch()
    
            while self.batch_data.have_more_indices:
                while self.batch_data.request_count == 0 and self.batch_data.qsize() > self.max_batches_pending:
                    if self.is_shutdown:
                        return
                    if not self.batch_generation_started.is_set():
                        break
                    self.batch_data.wait()
                
                if not self.batch_generation_started.is_set():
                    break
                
                current_batch_index = self.batch_data.next_index()
                self.total_batches_seen += 1
                offset = current_batch_index*self.batch_size
                batch_index_chunk = index_array[offset:offset+self.batch_size]
                
                batch_filenames = []
                batch_classes = []
                
                for batch_index in batch_index_chunk:
                    batch_filenames.append(self.filenames[batch_index])
                    batch_classes.append(self.classes[batch_index])

                self._invoke_processing(current_batch_index, batch_filenames, batch_classes)
    

    def _invoke_processing(
        self, 
        batch_index:int, 
        batch_filenames:List[str], 
        batch_classes:List[int]
    ):
        try:
            self.pool(
                batch_index, 
                batch_filenames, 
                batch_classes, 
                params=self.process_params,
                pool_callback=self._pool_callback
            )
        except Exception as e:
            if not self.is_running:
                raise


    def _pool_callback(self, results):
        if results is not None:
            self.batch_data.put(results[0], results[1])
            

[docs]class ParallelProcessParams(): """Adds methods related to getting batches from filenames It includes the logic to transform image files to batches. """
[docs] def __init__( self, audio_data_generator, sample_rate, sample_length_ms, sample_shape, save_to_dir, save_prefix, save_format, subset, class_indices, dtype, frontend_dtype, directory, class_mode, get_batch_function, noaug_preprocessing_function, preprocessing_function, postprocessing_function, frontend_enabled, add_channel_dimension ): self.class_indices = class_indices self.dtype = dtype self.frontend_dtype = frontend_dtype self.directory = directory self.class_mode = class_mode self.audio_data_generator = audio_data_generator self.get_batch_function = get_batch_function self.noaug_preprocessing_function = noaug_preprocessing_function self.preprocessing_function = preprocessing_function self.postprocessing_function = postprocessing_function self.frontend_enabled = frontend_enabled self.add_channel_dimension = add_channel_dimension self.sample_rate = sample_rate self.sample_length_ms = sample_length_ms self.sample_shape = sample_shape if frontend_enabled and len(self.sample_shape) == 2 and add_channel_dimension: self.sample_shape += (1,) # The 'depth' dimension to 1 self.save_to_dir = save_to_dir self.save_prefix = save_prefix self.save_format = save_format self.validation_split = audio_data_generator.validation_split if subset is not None: validation_split = audio_data_generator.validation_split if subset == 'validation': split = (0, validation_split) elif subset == 'training': split = (validation_split, 1) else: raise ValueError( f'Invalid subset name: {subset}; ' 'expected "training" or "validation"' ) else: split = None self.split = split self.subset = subset
def get_batches_of_transformed_samples( batch_index:int, filenames:List[str], classes:List[int], params:ParallelProcessParams ) -> Tuple[int, Tuple[np.ndarray, np.ndarray]]: """Gets a batch of transformed samples. Arguments: batch_index: Index of this batch filenames: List of filenames for this batch classes: List of class ids mapping to the filenames list params: Generator parameters # Returns A batch of transformed samples: batch_index, (batch_x, batch_y) """ batch_shape = (len(filenames),) + params.sample_shape batch_x = np.zeros(batch_shape, dtype=params.dtype) # build batch of image data # Ensure the RNG is unique for each batch if params.subset != 'validation' or params.audio_data_generator.validation_augmentation_enabled: random.seed(batch_index + int(time.time())) np.random.seed(batch_index + int(time.time())) for i, filename in enumerate(filenames): class_id = classes[i] if filename: filepath = os.path.join(params.directory, filename) x, orignal_sr = audio_utils.read_audio_file(filepath, return_sample_rate=True, return_numpy=True) else: orignal_sr = 16000 x = np.zeros((orignal_sr,), dtype='float32') # At this point, # x = [sample_length] dtype=float32 if params.noaug_preprocessing_function is not None: kwargs = _add_optional_callback_arguments( params.noaug_preprocessing_function, batch_index=i, class_id=class_id, filename=filename, batch_class_ids=classes, batch_filenames=filenames ) x = params.noaug_preprocessing_function(params, x, **kwargs) if params.subset != 'validation' or params.audio_data_generator.validation_augmentation_enabled: transform_params = params.audio_data_generator.get_random_transform() else: transform_params = params.audio_data_generator.default_transform # Apply any audio augmentations # NOTE: If transform_params = default_transform # Then the audio sample is simply cropped/padded to fit the expected sample length x = params.audio_data_generator.apply_transform(x, orignal_sr, transform_params) if params.preprocessing_function is not None: kwargs = _add_optional_callback_arguments( params.preprocessing_function, batch_index=i, class_id=class_id, filename=filename, batch_class_ids=classes, batch_filenames=filenames ) x = params.preprocessing_function(params, x, **kwargs) if params.frontend_enabled: # If a frontend dtype was specified use that, # otherwise just use the output dtype frontend_dtype = params.frontend_dtype or params.dtype # After point through the frontend, # x = [height, width] dtype=frontend_dtype x = params.audio_data_generator.apply_frontend(x, dtype=frontend_dtype) # Perform any post processing as necessary if params.postprocessing_function is not None: kwargs = _add_optional_callback_arguments( params.postprocessing_function, batch_index=i, class_id=class_id, filename=filename, batch_class_ids=classes, batch_filenames=filenames ) x = params.postprocessing_function(params, x, **kwargs) if params.frontend_enabled: # Do any standardizations (which are done using float32 internally) x = params.audio_data_generator.standardize(x) if params.add_channel_dimension: # Convert the sample's shape from [height, width] # to [height, width, 1] x = np.expand_dims(x, axis=-1) batch_x[i] = x # build batch of labels if params.class_mode == 'input': batch_y = batch_x.copy() elif params.class_mode in {'binary', 'sparse'}: batch_y = np.empty(len(batch_x), dtype=params.dtype) for i, class_id in enumerate(classes): batch_y[i] = class_id elif params.class_mode == 'categorical': batch_y = np.zeros((len(batch_x), len(params.class_indices)), dtype=params.dtype) for i, class_id in enumerate(classes): batch_y[i, class_id] = 1. else: return batch_index, batch_x return batch_index, (batch_x, batch_y) class BatchData: def __init__( self, n:int, shuffle:bool, pool:ProcessPool ): self.n = n self.shuffle = shuffle self.pool = pool self.batch_data = queue.Queue() if shuffle else {} self.batch_data_lock = threading.Condition() self.indices_lock = threading.Condition() self.indices = [] self.batch_counts = [] self.saved_batch_zero = None self.requests = [] self.data_event = threading.Event() @property def have_more_indices(self): with self.indices_lock: return (len(self.indices) + len(self.requests)) > 0 @property def request_count(self): with self.indices_lock: return len(self.requests) def start_batch(self): with self.indices_lock: self.batch_counts.append(self.n) self.indices = [i for i in range(self.n)] def next_index(self): with self.indices_lock: if len(self.requests) > 0: idx = self.requests.pop(0) try: self.indices.remove(idx) except: pass return idx else: return self.indices.pop(0) def wait(self): self.data_event.clear() while self.pool.is_running: if self.data_event.wait(timeout=.1): return True return False def reset(self): if self.shuffle: while not self.batch_data.empty(): self.batch_data.get() else: with self.batch_data_lock: self.batch_data.clear() with self.indices_lock: self.requests = [] self.indices = [i for i in range(self.n)] def qsize(self): if self.shuffle: return self.batch_data.qsize() else: with self.batch_data_lock: return len(self.batch_data) def put(self, index, value): if self.shuffle: self.batch_data.put(value) else: with self.batch_data_lock: self.batch_data[index] = value self.batch_data_lock.notify_all() def get(self, index, save_batch_zero=False): decrement_batch_count = True # If we're returning batch zero and we have a saved one, # then just return that batch if index == 0 and self.saved_batch_zero is not None: retval = self.saved_batch_zero self.saved_batch_zero = None elif self.shuffle: while True: if not self.pool.is_running: raise StopIteration('The data generator has been stopped') try: retval = self.batch_data.get(timeout=0.1) break except queue.Empty: continue else: with self.batch_data_lock: if index not in self.batch_data: with self.indices_lock: self.requests.append(index) self.data_event.set() while index not in self.batch_data: if not self.pool.is_running: raise StopIteration('The data generator has been stopped') self.batch_data_lock.wait(timeout=0.1) retval = self.batch_data[index] del self.batch_data[index] # If this is batch0 and we should save it # then saved a reference to it and do NOt decrement the batch count as it will be returned in a later call if index == 0 and save_batch_zero: self.saved_batch_zero = retval decrement_batch_count = False is_last_in_batch = self._decrement_current_batch_count() if decrement_batch_count else False self.data_event.set() return retval, is_last_in_batch def _decrement_current_batch_count(self) -> bool: is_last_in_batch = False with self.indices_lock: current_count = self.batch_counts[0] if current_count == 1: is_last_in_batch = True self.batch_counts.pop(0) else: self.batch_counts[0] = current_count - 1 return is_last_in_batch def _add_optional_callback_arguments( func, batch_index, class_id, filename, batch_class_ids, batch_filenames ) -> dict: retval = {} args = inspect.getfullargspec(func).args if 'batch_index' in args: retval['batch_index'] = batch_index if 'class_id' in args: retval['class_id'] = class_id if 'filename' in args: retval['filename'] = filename if 'batch_class_ids' in args: retval['batch_class_ids'] = batch_class_ids if 'batch_filenames' in args: retval['batch_filenames'] = batch_filenames return retval