Source code for mltk.core.evaluate_classifier

from typing import List, Tuple,Union
import logging
import json

import tqdm
import tensorflow as tf
from sklearn.metrics import (precision_recall_curve, confusion_matrix)
from sklearn.preprocessing import label_binarize
import numpy as np
import matplotlib.pyplot as plt
from mltk.utils import gpu
from mltk.utils.python import prepend_exception_msg
from .model import (
    model_utils,
    MltkModel,
    MltkModelEvent,
    KerasModel,
    TrainMixin,
    DatasetMixin,
    EvaluateClassifierMixin,
    load_tflite_or_keras_model,

)
from .tflite_model import TfliteModel
from .utils import get_mltk_logger
from .summarize_model import summarize_model
from .evaluation_results import EvaluationResults



[docs]class ClassifierEvaluationResults(EvaluationResults): """Classifier evaluation results .. seealso:: - :py:func:`~evaluate_classifier` - :py:func:`mltk.core.evaluate_model` """
[docs] def __init__(self, *args, **kwargs): EvaluationResults.__init__(self, *args, model_type='classification', **kwargs)
@property def classes(self) -> List[str]: """List of class labels used by evaluated model""" return self['classes'] @property def overall_accuracy(self) -> float: """The overall, model accuracy""" return self['overall_accuracy'] @property def class_accuracies(self) -> List[float]: """List of each classes' accuracy""" return self['class_accuracies'] @property def false_positive_rate(self) -> float: """The false positive rate""" return self['fpr'] @property def fpr(self) -> float: """The false positive rate""" return self['fpr'] @property def tpr(self) -> float: """The true positive rate""" return self['tpr'] @property def roc_auc(self) -> List[float]: """The area under the curve of the Receiver operating characteristic for each class""" return self['roc_auc'] @property def roc_thresholds(self) -> List[float]: """The list of thresholds used to calculate the Receiver operating characteristic""" return self['roc_thresholds'] @property def roc_auc_avg(self) -> List[float]: """The average of each classes' area under the curve of the Receiver operating characteristic""" return self['roc_auc_avg'] @property def precision(self) -> List[List[float]]: """List of each classes' precision at various thresholds""" return self['precision'] @property def recall(self) -> List[List[float]]: """List of each classes' recall at various thresholds""" return self['recall'] @property def confusion_matrix(self) -> List[List[float]]: """Calculated confusion matrix""" return self['confusion_matrix']
[docs] def calculate(self, y: Union[np.ndarray,list], y_pred: Union[np.ndarray,list]): """Calculate the evaluation results Given the expected y values and corresponding predictions, calculate the various evaluation results Args: y: 1D array with shape [n_samples] where each entry is the expected class label (aka id) for the corresponding sample e.g. 0 = cat, 1 = dog, 2 = goat, 3 = other y_pred: 2D array as shape [n_samples, n_classes] for categorical or 1D array as [n_samples] for binary, where each entry contains the model output for the given sample. For binary, the values must be between 0 and 1 where < 0.5 maps to class 0 and >= 0.5 maps to class 1 """ if not isinstance(y, np.ndarray): y = np.asarray(y) if not isinstance(y_pred, np.ndarray): y_pred = np.asarray(y_pred) if len(y_pred.shape) == 2: if y_pred.shape[1] == 1: y_pred = np.squeeze(y_pred, -1) if len(y_pred.shape) == 1: n_classes = 2 n_samples = len(y_pred) y_pred_orig = y_pred y_pred = np.zeros((n_samples, n_classes), dtype=np.float32) for i, pred in enumerate(y_pred_orig): class_id = 0 if pred < 0.5 else 1 y_pred[i][class_id] = pred else: n_classes = y_pred.shape[1] if 'classes' not in self or not self['classes']: self['classes'] = [str(x) for x in range(n_classes)] if len(y.shape) == 2: if y.shape[1] == 1: y = np.squeeze(y, -1) assert len(y) == len(y_pred), 'y and y_pred must have same number of samples' self['overall_accuracy'] = calculate_overall_accuracy(y_pred, y) self['class_accuracies'] = calculate_per_class_accuracies(y_pred, y) self['fpr'], self['tpr'], self['roc_auc'], self['roc_thresholds'] = calculate_auc(y_pred, y) self['roc_auc_avg'] = sum(self['roc_auc']) / n_classes self['precision'], self['recall'] = calculate_precision_recall(y_pred, y) self['confusion_matrix'] = calculate_confusion_matrix(y_pred, y)
[docs] def generate_summary(self) -> str: """Generate and return a summary of the results as a string""" s = super().generate_summary(include_all=False) return s + '\n' + summarize_results(self)
[docs] def generate_plots( self, show=True, output_dir:str=None, logger: logging.Logger=None ): """Generate plots of the evaluation results Args: show: Display the generated plots output_dir: Generate the plots at the specified directory. If omitted, generated in the model's logging directory logger: Optional logger """ plot_results( self, logger=logger, output_dir=output_dir, show=show )
[docs]def evaluate_classifier( mltk_model:MltkModel, tflite:bool=False, weights:str=None, max_samples_per_class:int=-1, classes:List[str]=None, verbose:bool=False, show:bool=False, update_archive:bool=True, **kwargs ) -> ClassifierEvaluationResults: """Evaluate a trained classification model Args: mltk_model: MltkModel instance tflite: If true then evalute the .tflite (i.e. quantized) model, otherwise evaluate the keras model weights: Optional weights to load before evaluating (only valid for a keras model) max_samples_per_class: Maximum number of samples per class to evaluate. This is useful for large datasets classes: Specific classes to evaluate verbose: Enable progress bar show: Show the evaluation results diagrams update_archive: Update the model archive with the eval results Returns: Dictionary containing evaluation results """ if not isinstance(mltk_model, TrainMixin): raise RuntimeError('MltkModel must inherit TrainMixin') if not isinstance(mltk_model, EvaluateClassifierMixin): raise RuntimeError('MltkModel must inherit EvaluateClassifierMixin') if not isinstance(mltk_model, DatasetMixin): raise RuntimeError('MltkModel must inherit a DatasetMixin') logger = mltk_model.create_logger('eval', parent=get_mltk_logger()) try: mltk_model.load_dataset( subset='evaluation', max_samples_per_class=max_samples_per_class, classes=classes, test=mltk_model.test_mode_enabled ) except Exception as e: prepend_exception_msg(e, 'Failed to load model evaluation dataset') raise try: # Build the MLTK model's corresponding as a Keras model or .tflite try: built_model = load_tflite_or_keras_model( mltk_model, model_type='tflite' if tflite else 'h5', weights=weights ) except Exception as e: prepend_exception_msg(e, 'Failed to build model') raise try: summary = summarize_model( mltk_model, built_model=built_model ) logger.info(summary) except Exception as e: logger.debug(f'Failed to generate model summary, err: {e}', exc_info=e) logger.warning(f'Failed to generate model summary, err: {e}') logger.info(mltk_model.summarize_dataset()) results = evaluate_classifier_with_built_model( mltk_model=mltk_model, built_model=built_model, verbose=verbose, show=show, logger=logger, update_archive=update_archive, ) finally: mltk_model.unload_dataset() return results
def evaluate_classifier_with_built_model( mltk_model:MltkModel, built_model:Union[KerasModel, TfliteModel], verbose:bool=False, show:bool=False, logger:logging.Logger = None, update_archive:bool=True, ) -> ClassifierEvaluationResults: """Evaluate a trained classification model with built model Args: mltk_model: MltkModel instance built_model: Built Keras or TfliteModel verbose: Enable progress bar show: Show the evaluation results diagrams update_archive: Update the model archive with the eval results logger: Optional python logger Returns: Dictionary containing evaluation results """ if update_archive: update_archive = mltk_model.check_archive_file_is_writable() subdir = 'eval/tflite' if isinstance(built_model, TfliteModel) else 'eval/h5' eval_dir = mltk_model.create_log_dir(subdir, delete_existing=True) logger = logger or mltk_model.create_logger('eval', parent=get_mltk_logger()) gpu.initialize(logger=logger) y_label, y_pred = generate_predictions( mltk_model=mltk_model, built_model=built_model, verbose=verbose ) results = ClassifierEvaluationResults( name=mltk_model.name, classes=getattr(mltk_model, 'classes', None) ) results.calculate( y=y_label, y_pred=y_pred, ) eval_results_path = f'{eval_dir}/eval-results.json' with open(eval_results_path, 'w') as f: json.dump(results, f) logger.debug(f'Generated {eval_results_path}') summary_path = f'{eval_dir}/summary.txt' with open(summary_path, 'w') as f: f.write(results.generate_summary()) logger.debug(f'Generated {summary_path}') results.generate_plots( logger=logger, output_dir=eval_dir, show=show ) if update_archive: try: logger.info(f'Updating {mltk_model.archive_path}') mltk_model.add_archive_dir(subdir) except Exception as e: logger.warning(f'Failed to add eval results to model archive, err: {e}', exc_info=e) if show: plt.show(block=True) return results def generate_predictions( mltk_model: MltkModel, built_model:Union[KerasModel, TfliteModel], verbose:bool=None ) -> Tuple[np.ndarray, np.ndarray]: """Generate predictions using evaluation data Args: mltk_model: MltkModel instance built_model: Built/trained Keras or TfliteModel verbose: Enable progress bar Returns: (y_label, y_pred) The evaluation sample labels and corresponding model predictions """ y_pred = [] y_label = [] with get_progbar(mltk_model, verbose) as progbar: for batch_x, batch_y in iterate_evaluation_data(mltk_model): if isinstance(built_model, KerasModel): pred = built_model.predict(batch_x, verbose=0) else: pred = built_model.predict(batch_x, y_dtype=np.float32) progbar.update(len(pred)) y_pred.extend(pred) if batch_y.shape[-1] == 1 or len(batch_y.shape) == 1: y_label.extend(batch_y) else: y_label.extend(np.argmax(batch_y, -1)) y_pred = list_to_numpy_array(y_pred) y_label = np.asarray(y_label, dtype=np.int32) return y_label, y_pred def plot_results( results:ClassifierEvaluationResults, show=True, output_dir:str=None, logger:logging.Logger=None ): """Use Matlibplot to generate plots of the evaluation results""" plot_roc(results, show=show, output_dir=output_dir, logger=logger) plot_precision_vs_recall(results, show=show, output_dir=output_dir, logger=logger) plot_tpr(results, show=show, output_dir=output_dir, logger=logger) plot_fpr(results, show=show, output_dir=output_dir, logger=logger) plot_tpr_and_fpr(results, show=show, output_dir=output_dir, logger=logger) plot_confusion_matrix(results, show=show, output_dir=output_dir, logger=logger) def summarize_results(results: ClassifierEvaluationResults) -> str: """Generate a summary of the evaluation results""" classes = results['classes'] class_accuracies = zip(classes, results['class_accuracies']) class_accuracies = sorted(class_accuracies, key=lambda x: x[1], reverse=True) class_auc = zip(classes, results['roc_auc']) class_auc = sorted(class_auc, key=lambda x: x[1], reverse=True) s = '' s += 'Overall accuracy: {:.3f}%\n'.format(results['overall_accuracy'] * 100) s += 'Class accuracies:\n' for class_label, acc in class_accuracies: s += '- {} = {:.3f}%\n'.format(class_label, acc * 100) s += 'Average ROC AUC: {:.3f}%\n'.format(results['roc_auc_avg'] * 100) s += 'Class ROC AUC:\n' for class_label, auc in class_auc: s += '- {} = {:.3f}%\n'.format(class_label, auc * 100) return s def calculate_overall_accuracy(y_pred:np.ndarray, y_label:np.ndarray) -> float: """ Classifier overall accuracy calculation y_pred contains model predictions [n_samples, n_classes] y_label list of each correct class id per sample [n_samples] Return overall accuracy (i.e. ratio) as float """ n_samples = len(y_pred) y_pred_label = np.argmax(y_pred, axis=1) correct = np.sum(y_label == y_pred_label) return correct / n_samples def calculate_per_class_accuracies(y_pred:np.ndarray, y_label:np.ndarray) -> List[float]: """Classifier accuracy per class y_pred contains model predictions [n_samples, n_classes] y_label list of each correct class id per sample [n_samples] Return list of each classes' accuracy """ n_samples, n_classes = y_pred.shape # Initialize array of accuracies accuracies = np.zeros(n_classes) # Loop on classes for class_id in range(n_classes): true_positives = 0 # Loop on all predictions for i in range(n_samples): # Check if it matches the class that we are working on if y_label[i] == class_id: # Get prediction label y_pred_label = np.argmax(y_pred[i,:]) # Check if the prediction is correct if y_pred_label == class_id: true_positives += 1 accuracies[class_id] = _safe_divide(true_positives, np.sum(y_label == class_id)) return accuracies.tolist() def calculate_auc(y_pred:np.ndarray, y_label:np.ndarray, threshold=.01) -> Tuple[float, float, List[float], List[float]]: """Classifier ROC AUC calculation y_pred contains model predictions [n_samples, n_classes] y_label list of each correct class id per sample [n_samples] thresholds Optional list of thresholds to consider Return tuple: false positive rate, true positive rate, list ROC AUC for each class, list of thresholds """ n_samples, n_classes = y_pred.shape # thresholds, linear range thresholds = np.arange(0.0, 1.01, threshold) n_thresholds = len(thresholds) # false positive rate fpr = np.zeros((n_classes, n_thresholds)) # true positive rate tpr = np.zeros((n_classes, n_thresholds)) # area under curve roc_auc = np.zeros(n_classes) # get number of positive and negative examples in the dataset for class_item in range(n_classes): # Sum of all true positive answers all_positives = sum(y_label == class_item) # Sum of all true negative answers all_negatives = len(y_label) - all_positives # iterate through all thresholds and determine fraction of true positives # and false positives found at this threshold for threshold_item in range(1, n_thresholds): threshold = thresholds[threshold_item] false_positives = 0 true_positives = 0 for i in range(n_samples): # Check prediction for this threshold if (y_pred[i, class_item] > threshold): if y_label[i] == class_item: true_positives += 1 else: false_positives += 1 fpr[class_item, threshold_item] = _safe_divide(false_positives, float(all_negatives)) tpr[class_item, threshold_item] = _safe_divide(true_positives, float(all_positives)) # Force boundary condition fpr[class_item,0] = 1 tpr[class_item,0] = 1 # calculate area under curve, trapezoid integration for threshold_item in range(len(thresholds)-1): roc_auc[class_item] += .5*(tpr[class_item,threshold_item]+tpr[class_item,threshold_item+1])*(fpr[class_item,threshold_item]-fpr[class_item,threshold_item+1]) return fpr.tolist(), tpr.tolist(), roc_auc.tolist(), thresholds.tolist() def calculate_precision_recall(y_pred:np.ndarray, y_label:np.ndarray) -> Tuple: """Calculate precision and recall """ _, n_classes = y_pred.shape precision = [None] * n_classes recall = [None] * n_classes y_true = _label_binarize(y_label) for class_id in range(n_classes): class_precision, class_recall, _ = precision_recall_curve(y_true[:, class_id], y_pred[:, class_id]) precision[class_id] = class_precision.tolist() recall[class_id] = class_recall.tolist() return precision, recall def calculate_confusion_matrix(y_pred:np.ndarray, y_label:np.ndarray): """Calculate the confusion matrix """ y_true = _label_binarize(y_label) cm_npy = confusion_matrix(np.argmax(y_true, axis=1), np.argmax(y_pred, axis=1)) return cm_npy.tolist() def plot_roc(results:dict, output_dir:str, show:bool, logger: logging.Logger): """Generate a plot of the AUC ROC evaluation results""" name = results['name'] classes = results['classes'] fpr = results['fpr'] tpr = results['tpr'] roc_auc = results['roc_auc'] n_classes = len(classes) title = f'ROC: {name}' fig = plt.figure(title, figsize=(10,8)) for i in range(n_classes): plt.plot(fpr[i], tpr[i], label=f'AUC: {roc_auc[i]:0.5f} ({classes[i]})') plt.plot([], [], ' ', label='Average AUC: {:.5f}'.format(results['roc_auc_avg'])) plt.xlim([0.0, 0.1]) plt.ylim([0.5, 1.01]) plt.legend(loc="lower right") plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title(title, y=0) plt.grid(which='major') plt.tight_layout() _trigger_event(fig=fig, name='roc', output_dir=output_dir, logger=logger) if output_dir: output_path = output_dir + f'/{name}-roc.png' plt.savefig(output_path) logger.debug(f'Generated {output_path}') if show: plt.show(block=False) else: fig.clear() plt.close(fig) def plot_precision_vs_recall(results:dict, output_dir:str, show, logger: logging.Logger): """Generate a plot of the precision vs recall""" name = results['name'] classes = results['classes'] precision = results['precision'] recall = results['recall'] n_classes = len(classes) title = f'Precision vs Recall: {name}' fig = plt.figure(title, figsize=(10,8)) for i in range(n_classes): plt.plot(recall[i], precision[i], label=_normalize_class_name(classes[i])) plt.xlim([0.5, 1.0]) plt.ylim([0.5, 1.01]) plt.legend(loc="lower left") plt.xlabel('Recall') plt.ylabel('Precision') plt.title(title) plt.grid() plt.tight_layout() _trigger_event(fig=fig, name='precision_vs_recall', output_dir=output_dir, logger=logger) if output_dir: output_path = output_dir + f'/{name}-precision_vs_recall.png' plt.savefig(output_path) logger.debug(f'Generated {output_path}') if show: plt.show(block=False) else: fig.clear() plt.close(fig) def plot_tpr(results:dict, output_dir:str, show, logger: logging.Logger): """Generate a plot of the threshold vs TPR""" name = results['name'] classes = results['classes'] tpr = results['tpr'] thresholds = results['roc_thresholds'] n_classes = len(classes) title = f'Thres vs True Positive: {name}' fig = plt.figure(title, figsize=(10,8)) for i in range(n_classes): plt.plot(thresholds, tpr[i], label=_normalize_class_name(classes[i])) plt.xlim([0.0, 1.0]) plt.ylim([0.8, 1.01]) plt.legend(loc="lower left") plt.xlabel('Threshold') plt.ylabel('True Positive Rate') plt.title(title) plt.grid() plt.tight_layout() _trigger_event(fig=fig, name='tpr', output_dir=output_dir, logger=logger) if output_dir: output_path = output_dir + f'/{name}-tpr.png' plt.savefig(output_path) logger.debug(f'Generated {output_path}') if show: plt.show(block=False) else: fig.clear() plt.close(fig) def plot_fpr(results:dict, output_dir:str, show, logger: logging.Logger): """Generate a plot of the threshold vs FPR""" name = results['name'] classes = results['classes'] fpr = results['fpr'] thresholds = results['roc_thresholds'] n_classes = len(classes) title = f'Thres vs False Positive: {name}' fig = plt.figure(title, figsize=(10,8)) for i in range(n_classes): plt.plot(thresholds, fpr[i], label=_normalize_class_name(classes[i])) plt.xlim([0.0, 1.0]) plt.ylim([0.0, 0.1]) plt.legend(loc="upper right") plt.xlabel('Threshold') plt.ylabel('False Positive Rate') plt.title(title) plt.grid() plt.tight_layout() _trigger_event(fig=fig, name='fpr', output_dir=output_dir, logger=logger) if output_dir: output_path = output_dir + f'/{name}-fpr.png' plt.savefig(output_path) logger.debug(f'Generated {output_dir}') if show: plt.show(block=False) else: fig.clear() plt.close(fig) def plot_tpr_and_fpr(results:dict, output_dir:str, show, logger: logging.Logger): """Generate a plot of the threshold vs FPR""" name = results['name'] classes = results['classes'] tpr = results['tpr'] fpr = results['fpr'] thresholds = results['roc_thresholds'] n_classes = len(classes) title = f'Thres vs True/False Positive: {name}' fig = plt.figure(title, figsize=(10,8)) for i in range(n_classes): plt.plot(thresholds, fpr[i], label=f'FPR: {_normalize_class_name(classes[i])}') plt.plot(thresholds, tpr[i], label=f'TPR: {_normalize_class_name(classes[i])}') plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.0]) plt.legend(loc="center left", bbox_to_anchor=(1, 0)) plt.xlabel('Threshold') plt.ylabel('True/False Positive Rate') plt.title(title) plt.grid() plt.tight_layout() _trigger_event(fig=fig, name='tpr_and_fpr', output_dir=output_dir, logger=logger) if output_dir: output_path = output_dir + f'/{name}-tfp_fpr.png' plt.savefig(output_path, bbox_inches='tight') logger.debug(f'Generated {output_dir}') if show: plt.show(block=False) else: fig.clear() plt.close(fig) def plot_confusion_matrix(results:dict, output_dir:str, show, logger: logging.Logger): """Generate a plot of the confusion matrix""" name = results['name'] classes = results['classes'] cm = results['confusion_matrix'] n_classes = len(classes) title = f'Confusion Matrix: {name}' fig = plt.figure(title, figsize=(6,6)) ax = fig.subplots() ax.imshow(cm) # We want to show all ticks ax.set_xticks(np.arange(n_classes)) ax.set_yticks(np.arange(n_classes)) # ... and label them with the respective list entries ax.set_xticklabels(classes) ax.set_yticklabels(classes) # Rotate the tick labels and set their alignment. plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") # Loop over data dimensions and create text annotations. for i in range(n_classes): for j in range(n_classes): ax.text(j, i, cm[i][j], ha="center", va="center", color="w", backgroundcolor=(0.41, 0.41, 0.41, 0.25)) ax.set_ylabel('Actual class') ax.set_xlabel('Predicted class') plt.title(title) plt.tight_layout() _trigger_event(fig=fig, name='confusion_matrix', output_dir=output_dir, logger=logger) if output_dir: output_path = output_dir + f'/{name}-confusion_matrix.png' plt.savefig(output_path) logger.debug(f'Generated {output_dir}') if show: plt.show(block=False) else: fig.clear() plt.close(fig) def _trigger_event(fig, name:str, output_dir:str, logger:logging.Logger): model_utils.trigger_model_event( MltkModelEvent.GENERATE_EVALUATE_PLOT, name=name, tflite=output_dir and output_dir.endswith('tflite'), fig=fig, logger=logger ) def _safe_divide(num, dem): """Standard division but it denominator is 0 then return 0""" if dem == 0: return 0 else: return num / dem def _label_binarize(y_label): """This calls label_binarize() but ensures the return value always has the shape: [n_samples, n_classes]""" # If n_classes == 2: # y_true = (n_samples, 1) # else: # y_true = (n_samples, n_classes) y_true = label_binarize(y=y_label, classes=np.arange(np.max(y_label)+1)) # Handle case with only 2 classes if y_true.shape[1] == 1: y_tmp = np.empty((y_true.shape[0], 2), dtype=y_true.dtype) y_tmp[:, 0] = 1-y_true[:, 0] y_tmp[:, 1] = y_true[:, 0] y_true = y_tmp return y_true def _normalize_class_name(label:str) -> str: if label.startswith('_'): label = label[1:] if label.endswith('_'): label = label[:-1] return label def iterate_evaluation_data(mltk_model:MltkModel): """Iterate over the given MltkModel's evaluation data This returns a generator, which generates tuples of (batch_x, batch_y), numpy arrays of the evaluation data provided by the given mltk_model instance. """ x = mltk_model.validation_data if x is None: x = mltk_model.x y = mltk_model.y if y is not None: if isinstance(x, tf.Tensor): x = x.numpy() y = y.numpy() if isinstance(x, np.ndarray): yield x, y else: for batch_x, batch_y in zip(x, y): batch_x = _convert_tf_tensor_to_numpy_array(batch_x, expand_dim=0) batch_y = _convert_tf_tensor_to_numpy_array(batch_y, expand_dim=0) yield batch_x, batch_y else: for batch in x: batch_x, batch_y, _ = tf.keras.utils.unpack_x_y_sample_weight(batch) batch_x = _convert_tf_tensor_to_numpy_array(batch_x) batch_y = _convert_tf_tensor_to_numpy_array(batch_y) yield batch_x, batch_y def list_to_numpy_array(python_list:List[np.ndarray], dtype=None) -> np.ndarray: """Convert the given Python list of numpy arrays to a single numpy array""" n_samples = len(python_list) if len(python_list[0].shape) > 0: numpy_array_shape = (n_samples,) + python_list[0].shape else: numpy_array_shape = (n_samples,) numpy_array = np.empty(numpy_array_shape, dtype=dtype or python_list[0].dtype) for i, pred in enumerate(python_list): numpy_array[i] = pred return numpy_array def _convert_tf_tensor_to_numpy_array(x, expand_dim=None): if isinstance(x, tf.Tensor): x = x.numpy() elif isinstance(x, (list,tuple)): if isinstance(x[0], np.ndarray) and expand_dim is not None: return x retval = [] for i in x: retval.append(_convert_tf_tensor_to_numpy_array(i, expand_dim=expand_dim)) return tuple(retval) if expand_dim is not None: x = np.expand_dims(x, axis=expand_dim) return x class ClassifierEvaluationProgressBar: def __init__(self, n_samples) -> None: if n_samples: self.progbar = tqdm.tqdm(unit='prediction', desc='Evaluating', total=n_samples) else: self.progbar = None def update(self, n): if self.progbar is not None: self.progbar.update(n) def __enter__(self): return self def __exit__(self, *args, **kwargs): pass def get_progbar(mltk_model:MltkModel, verbose:bool) -> ClassifierEvaluationProgressBar: """Return a tqdm progessbar if verbose=True If verbose=False, return an empty ClassifierEvaluationProgressBar """ try: class_counts = getattr(mltk_model, 'class_counts', {}) eval_class_counts = class_counts.get('evaluation', {}) valid_class_counts = class_counts.get('validation', {}) eval_n_samples = sum(eval_class_counts.values()) valid_n_samples = sum(valid_class_counts.values()) if eval_n_samples > 0: n_samples = eval_n_samples elif valid_n_samples > 0: n_samples = valid_n_samples else: n_samples = sum(class_counts.values()) or None except: n_samples = None return ClassifierEvaluationProgressBar(n_samples if verbose else None)