"""Source code for mltk.core.tflite_micro.tflite_micro_accelerator"""
import logging
from typing import List, Optional

from mltk.core.tflite_model import TfliteModel
from mltk.utils.python import get_case_insensitive
[docs]class TfliteMicroAccelerator:
"""TF-Lite Micro Accelerator
This class allows for providing hardware-accelerated
kernels to the TFLM interpreter.
"""
[docs] def __init__(self, accelerator_wrapper):
self._accelerator_wrapper = accelerator_wrapper
self._active_variant:str = None
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
result._accelerator_wrapper = self._accelerator_wrapper
result._active_variant = self._active_variant
memo[id(self)] = result
return result
@property
def name(self) -> str:
"""The name of the accelerator"""
return self._accelerator_wrapper.name()
@property
def variants(self) -> List[str]:
"""List of variants supported by this accelerator"""
if hasattr(self._accelerator_wrapper, 'variants'):
return self._accelerator_wrapper.variants()
else:
return [self.name]
@property
def active_variant(self) -> str:
"""The name of the variant actively being used by this accelerator"""
return self._active_variant
@active_variant.setter
def active_variant(self, v:str):
v = get_case_insensitive(v, self.variants)
if not v:
raise ValueError('Unknown variant')
self._active_variant = v
@property
def api_version(self) -> int:
"""The API version number this wrapper was built with
This number must match the tflite_micro_wrapper's API version
"""
if hasattr(self._accelerator_wrapper, 'api_version'):
return self._accelerator_wrapper.api_version()
else:
return None
@property
def git_hash(self) -> str:
"""Return the GIT hash of the MLTK repo used to compile the wrapper library"""
if hasattr(self._accelerator_wrapper, 'git_hash'):
return self._accelerator_wrapper.git_hash()
else:
return None
@property
def accelerator_wrapper(self) -> object:
"""Return the TfliteMicroAcceleratorWrapper instance
"""
if hasattr(self._accelerator_wrapper, 'get_accelerator_wrapper'):
return self._accelerator_wrapper.get_accelerator_wrapper()
else:
return None
@property
def supports_model_compilation(self) -> bool:
"""Return if this accelerator supports model compilation"""
return type(self).compile_model != TfliteMicroAccelerator.compile_model
[docs] def estimate_profiling_results(
self,
results, # ProfilingModelResults
**kwargs
):
"""Update the given ProfilingModelResults with estimated model metrics"""
[docs] def set_program_recorder_enabled(self, enabled:bool):
"""Enable the accelerator instruction recorder"""
if hasattr(self._accelerator_wrapper, 'set_program_recorder_enabled'):
return self._accelerator_wrapper.set_program_recorder_enabled(enabled)
else:
return None
[docs] def enable_data_recorder(self):
"""Enable the accelerator data recorder"""
if hasattr(self._accelerator_wrapper, 'enable_data_recorder'):
return self._accelerator_wrapper.enable_data_recorder()
else:
return None
[docs] def compile_model(
self,
model:TfliteModel,
logger:logging.Logger=None,
report_path:str=None,
**kwargs
) -> TfliteModel:
"""Compile the given .tflite model and return a new TfliteModel instance with the compiled data
NOTE: The accelerator must support model compilation to use this API
"""
raise RuntimeError(f'The accelerator: {self.name} does not support model compilation')
class PlaceholderTfliteMicroAccelerator(TfliteMicroAccelerator):
    """This accelerator does NOT have a corresponding Python wrapper"""

    def __init__(self, name: str):
        # No wrapper exists, so the name is stored locally rather than
        # queried from a wrapper object
        self._name = name
        super().__init__(None)

    def __deepcopy__(self, memo):
        # Copy only the pure-Python state; there is no wrapper to share
        clone = self.__class__.__new__(self.__class__)
        clone._accelerator_wrapper = None
        clone._name = self._name
        clone._active_variant = self._active_variant
        memo[id(self)] = clone
        return clone

    @property
    def name(self) -> str:
        """The name of the accelerator"""
        return self._name

    @property
    def api_version(self) -> int:
        """API version reported by the base TFLM wrapper"""
        from mltk.core.tflite_micro import TfliteMicro
        return TfliteMicro.api_version()