# Source code for mltk.core.tflite_model_parameters.flatbuffer_dictionary

# pylint: disable=unused-wildcard-import, wildcard-import
from typing import Union, List
import flatbuffers
import numpy as np

from .schema.Dictionary import *
from .schema.BinaryValue import *
from .schema.BoolValue import *
from .schema.DoubleValue import *
from .schema.Entry import *
from .schema.FloatValue import *
from .schema.Int16Value import *
from .schema.Int32Value import *
from .schema.Int64Value import *
from .schema.Int8Value import *
from .schema.StringList import *
from .schema.Int32List import * 
from .schema.FloatList import *
from .schema.StringValue import *
from .schema.Uint8Value import *
from .schema.Uint16Value import *
from .schema.Uint32Value import *
from .schema.Uint64Value import *
from .schema.Value import *


class FlatbufferDictionary(dict):
    """FlatbufferDictionary

    This class allows for adding scalar values to a standard Python dictionary,
    serializing the dictionary into a flatbuffer, and later de-serializing
    to another Python dictionary.

    A flatbuffer dictionary is a collection of key/value pairs where:

    - **key** - Name of parameter as a string
    - **value** - Value of parameter as a specific scalar data type

    The dictionary is serialized using the Flatbuffer schema
    `dictionary.fbs <./dictionary.fbs>`_

    .. note:: The FlatbufferDictionary object inherits the standard Python 'dict' class.
    """

    def __init__(self, *args, **kwargs):
        """Create the dictionary, forwarding all arguments to ``dict``.

        NOTE: Entries given here are handed directly to ``dict.__init__``,
        they do not pass through :meth:`put` and are therefore not validated
        until :meth:`serialize` is called.
        """
        dict.__init__(self, *args, **kwargs)

    def put(self, key:str, value:Union[str,int,float,bool,List[str],bytes], dtype:str=None):
        """Put an entry into the dictionary

        This API allows for specifying the value's datatype.
        Alternatively, you can use the standard Python dictionary syntax, e.g.:

            my_params.put('foo', 42, 'int32')

        OR

            my_params['foo'] = 42

        Args:
            key: The dictionary key to insert or overwrite
            value: The value of the entry. Must have a type of:
                str,int,float,bool,List[str], or bytes.
                A ``DictionaryValue`` is also accepted; the data type it
                carries is kept unless ``dtype`` is given explicitly.
            dtype: Optional. Force the value to have a specific data type.
                Must be a string and one of:
                bool,int8,int16,int32,int64,uint8,uint16,uint32,uint64,float,double,str,str_list,int32_list,float_list,bin

        Raises:
            ValueError: If the value does not match the (given or detected) data type
        """
        if isinstance(value, DictionaryValue):
            # Unwrap here so that:
            # - an explicit dtype argument is the one that actually gets validated
            # - the wrapper's own dtype is not silently dropped when no dtype is given
            if dtype is None:
                dtype = value.dtype
            value = value.value

        # Validate the value type and dtype
        _, value = _get_dtype_and_value(value, dtype)

        # Only entries with a forced data type are wrapped,
        # automatically typed entries are stored as plain Python values
        if dtype is not None:
            value = DictionaryValue(dtype=dtype, value=value)

        super().__setitem__(key, value)

    def __setitem__(self, key, value):
        """Route ``d[key] = value`` through :meth:`put` so the value is validated"""
        self.put(key, value)

    def serialize(self) -> bytes:
        """Serialize the current dictionary entries into a flatbuffer

        Returns:
            Serialized dictionary flatbuffer bytes

        Raises:
            Exception: Any error raised while converting an entry is re-raised
                with the offending key prepended to its message
        """
        builder = flatbuffers.Builder(0)

        entry_offsets = []
        for key, value in self.items():
            key_offset = builder.CreateString(key)

            try:
                value_type, value_offset = _generate_value(builder, value)
            except Exception as e:
                # Keep the original exception type but say which key failed
                e.args = (f'Key: {key}, {e}', )
                raise

            # _generate_value() returns no offset for null values, these are not serialized
            if not value_offset:
                continue

            EntryStart(builder)
            EntryAddKey(builder, key_offset)
            EntryAddValue(builder, value_offset)
            EntryAddValueType(builder, value_type)
            entry_offsets.append(EntryEnd(builder))

        # The offsets are prepended in reverse so the entries
        # end up in their original order within the vector
        DictionaryStartEntriesVector(builder, len(entry_offsets))
        for entry_offset in reversed(entry_offsets):
            builder.PrependUOffsetTRelative(entry_offset)
        entries_offset = builder.EndVector_Patched(len(entry_offsets))

        DictionaryStart(builder)
        DictionaryAddSchemaVersion(builder, FLATBUFFER_SCHEMA_VERSION)
        DictionaryAddEntries(builder, entries_offset)
        root = DictionaryEnd(builder)
        builder.Finish(root)

        return bytes(builder.Output())

    @staticmethod
    def deserialize(serialized_data:bytes):
        """Instantiate a FlatbufferDictionary object with the given serialized flatbuffer binary data

        Args:
            serialized_data: Flatbuffer bytes, e.g. as returned by :meth:`serialize`

        Returns:
            A new FlatbufferDictionary populated with the de-serialized entries

        Raises:
            RuntimeError: If the schema version is missing or newer than FLATBUFFER_SCHEMA_VERSION
            Exception: Any parsing error is re-raised with 'Failed to parse flatbuffer' prepended
        """
        # Load the flatbuffer serialized data
        params_fb = Dictionary.GetRootAsDictionary(serialized_data, 0)

        # Validate the schema version
        schema_version = params_fb.SchemaVersion()
        if not schema_version:
            raise RuntimeError('Flatbuffer missing schema version')
        if schema_version > FLATBUFFER_SCHEMA_VERSION:
            raise RuntimeError(f'Flatbuffer schema version ({schema_version}) not supported (max supported version:{FLATBUFFER_SCHEMA_VERSION})')

        # Instantiate a new FlatbufferDictionary object
        model_parameters = FlatbufferDictionary()

        # Populate the FlatbufferDictionary object with the
        # entries from the serialized flatbuffer data.
        # The assignment goes through put(), so each value is validated
        # and stored as a plain Python value (no forced data type)
        try:
            for i in range(params_fb.EntriesLength()):
                entry_fb = params_fb.Entries(i)
                key = entry_fb.Key().decode('utf-8')
                fb_value = entry_fb.Value()
                fb_type = entry_fb.ValueType()
                model_parameters[key] = _parse_value(fb_type, fb_value)
        except Exception as e:
            e.args = (f'Failed to parse flatbuffer, {e}', )
            raise

        return model_parameters

    def summary(self) -> str:
        """Generate a summary of the dictionary

        Returns:
            One ``key: value`` line per entry, sorted by key.
            List-like values are comma-joined, binary values are shown
            by size only, and values longer than 64 characters are truncated.
        """
        lines = []
        for key in sorted(self):
            value = self[key]
            if isinstance(value, (list, tuple, set)):
                # put() accepts any of these containers as a list value
                value_str = ','.join(str(x) for x in value)
            elif isinstance(value, (bytes, bytearray)):
                value_str = f'<{len(value)} bytes>'
            else:
                value_str = f'{value}'

            if len(value_str) > 64:
                value_str = f'{value_str[:64]} ...'

            lines.append(f'{key}: {value_str}')

        return '\n'.join(lines).strip()

    def __str__(self):
        """Return the same text as :meth:`summary`"""
        return self.summary()
class DictionaryValue:
    """Dictionary entry value

    A dictionary value contains both the actual scalar value
    and a corresponding data type
    """
    def __init__(self, value, dtype):
        self.value = value
        self.dtype = dtype

    def __str__(self):
        return f'{self.dtype}:{self.value}'


#######################################################################
# Internal flatbuffer parsing/generation code

# Maps each data type name accepted by the API to its tag in the flatbuffer Value union.
# The insertion order is the order used when listing the valid types in error messages.
_TYPE_MAP = {
    'bool': Value.boolean,
    'int8': Value.i8,
    'int16': Value.i16,
    'int32': Value.i32,
    'int64': Value.i64,
    'uint8': Value.u8,
    'uint16': Value.u16,
    'uint32': Value.u32,
    'uint64': Value.u64,
    'float': Value.f32,
    'double': Value.f64,
    'str': Value.str,
    'str_list': Value.str_list,
    'int32_list': Value.int32_list,
    'float_list': Value.float_list,
    'bin': Value.bin,
}

INT_TYPES = (int, np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32)
FLOAT_TYPES = (float, np.float32)

# Python containers that are accepted as a "list" value
_LIST_CONTAINER_TYPES = (list, tuple, set)

# Scalar data type name -> (conversion used to probe the value, type name used in the error message)
_SCALAR_PROBES = {
    'bool': (bool, 'bool'),
    'int8': (int, 'int'),
    'int16': (int, 'int'),
    'int32': (int, 'int'),
    'int64': (int, 'int'),
    'uint8': (int, 'int'),
    'uint16': (int, 'int'),
    'uint32': (int, 'int'),
    'uint64': (int, 'int'),
    'float': (float, 'float'),
    'double': (float, 'float'),
    'str': (str, 'str'),
    'bin': (bytes, 'bytes'),
}

# List data type name -> (accepted element types, element noun used in the error messages)
_LIST_ELEMENT_RULES = {
    'str_list': (str, 'strings'),
    'int32_list': (INT_TYPES, 'integers'),
    'float_list': (FLOAT_TYPES, 'floats'),
}


def _validate_explicit_dtype(value, dtype):
    """Raise ValueError if the value cannot be stored with the explicitly given data type"""
    if dtype not in _TYPE_MAP:
        raise ValueError(f'Unknown data type: {dtype}, valid types: {",".join(_TYPE_MAP.keys())}')

    if dtype in _SCALAR_PROBES:
        convert, type_name = _SCALAR_PROBES[dtype]
        try:
            # The converted result is discarded, this only checks that the conversion is possible
            convert(value)
        except Exception as e:
            raise ValueError(f'Failed to convert value to {type_name}, err: {e}') from e
        return

    element_types, noun = _LIST_ELEMENT_RULES[dtype]
    if not isinstance(value, _LIST_CONTAINER_TYPES):
        raise ValueError(f'dtype={dtype} but value is not an instance of a list')
    for entry in value:
        if not isinstance(entry, element_types):
            raise ValueError(f'Only list of {noun} are supported')


def _infer_int_dtype(value) -> str:
    """Return the name of the smallest integer data type able to hold the value"""
    # Compare as a Python int so the result does not depend on NumPy scalar promotion
    number = int(value)
    if number < 0:
        # A signed N-bit integer goes down to -(2 ** (N - 1))
        for name, bits in (('int8', 7), ('int16', 15), ('int32', 31), ('int64', 63)):
            if number >= -(2 ** bits):
                return name
    else:
        for name, bits in (('uint8', 8), ('uint16', 16), ('uint32', 32), ('uint64', 64)):
            if number < 2 ** bits:
                return name
    raise ValueError('Value overflow')


def _infer_list_dtype(value) -> str:
    """Return the list data type matching the entries of a list, tuple or set"""
    if len(value) == 0:
        # Default to a string list if the given list is empty
        return 'str_list'

    # Sets are not subscriptable, so take the first entry through the iterator protocol
    first = next(iter(value))
    for name, (element_types, _) in _LIST_ELEMENT_RULES.items():
        if isinstance(first, element_types):
            dtype = name
            break
    else:
        raise ValueError('Only list of strings, integers, or floats are supported')

    # The first entry decides the list type, every other entry must agree with it
    element_types, noun = _LIST_ELEMENT_RULES[dtype]
    for entry in value:
        if not isinstance(entry, element_types):
            raise ValueError(f'All entries in list must be {noun}')
    return dtype


def _get_dtype_and_value(value, dtype=None) -> tuple:
    """Determine and validate the data type of a value

    Args:
        value: Plain Python value or a DictionaryValue.
            For a DictionaryValue, the data type it carries replaces ``dtype``.
        dtype: Optional data type name (a key of _TYPE_MAP).
            If omitted, the data type is detected from the Python type of the value.

    Returns:
        Tuple (dtype name, unwrapped value)

    Raises:
        ValueError: If the data type is unknown, the value does not fit the
            data type, or the data type cannot be detected automatically
    """
    if isinstance(value, DictionaryValue):
        dtype = value.dtype
        value = value.value

    if dtype is not None:
        _validate_explicit_dtype(value, dtype)
        return dtype, value

    # bool must be tested before the integers since bool is a subclass of int
    if isinstance(value, bool):
        dtype = 'bool'
    elif isinstance(value, INT_TYPES):
        dtype = _infer_int_dtype(value)
    elif isinstance(value, FLOAT_TYPES):
        # Only values within the positive normalized float32 range are stored as 'float',
        # everything else (including zero and negative values) is stored as 'double'
        if 1.175494351e-38 <= value <= 3.402823466e+38:
            dtype = 'float'
        else:
            dtype = 'double'
    elif isinstance(value, str):
        dtype = 'str'
    elif isinstance(value, (bytes, bytearray)):
        dtype = 'bin'
    elif isinstance(value, _LIST_CONTAINER_TYPES):
        dtype = _infer_list_dtype(value)
    else:
        raise ValueError('Data type could not be automatically determined, you must manually specify the data type')

    return dtype, value


def _build_scalar_table(builder, start, add_value, end, converted_value):
    """Write a single-field flatbuffer table using the given generated functions and return its offset"""
    start(builder)
    add_value(builder, converted_value)
    return end(builder)


def _generate_value(builder:flatbuffers.Builder, value) -> tuple:
    """Convert the Python value into a flatbuffer value

    Args:
        builder: Flatbuffer builder that the value is written to
        value: Plain Python value or DictionaryValue

    Returns:
        Tuple (Value union tag, table offset), or (0, None) if the value is null

    Raises:
        ValueError: If the value fails validation or has an unknown data type
    """
    dtype, value = _get_dtype_and_value(value)
    if dtype == 'null' or value is None:
        return 0, None

    dtype = _TYPE_MAP[dtype]

    if dtype == Value.boolean:
        offset = _build_scalar_table(builder, BoolValueStart, BoolValueAddValue, BoolValueEnd, bool(value))
    elif dtype == Value.i8:
        offset = _build_scalar_table(builder, Int8ValueStart, Int8ValueAddValue, Int8ValueEnd, int(value))
    elif dtype == Value.i16:
        offset = _build_scalar_table(builder, Int16ValueStart, Int16ValueAddValue, Int16ValueEnd, int(value))
    elif dtype == Value.i32:
        offset = _build_scalar_table(builder, Int32ValueStart, Int32ValueAddValue, Int32ValueEnd, int(value))
    elif dtype == Value.i64:
        offset = _build_scalar_table(builder, Int64ValueStart, Int64ValueAddValue, Int64ValueEnd, int(value))
    elif dtype == Value.u8:
        offset = _build_scalar_table(builder, Uint8ValueStart, Uint8ValueAddValue, Uint8ValueEnd, int(value))
    elif dtype == Value.u16:
        offset = _build_scalar_table(builder, Uint16ValueStart, Uint16ValueAddValue, Uint16ValueEnd, int(value))
    elif dtype == Value.u32:
        offset = _build_scalar_table(builder, Uint32ValueStart, Uint32ValueAddValue, Uint32ValueEnd, int(value))
    elif dtype == Value.u64:
        offset = _build_scalar_table(builder, Uint64ValueStart, Uint64ValueAddValue, Uint64ValueEnd, int(value))
    elif dtype == Value.f32:
        offset = _build_scalar_table(builder, FloatValueStart, FloatValueAddValue, FloatValueEnd, float(value))
    elif dtype == Value.f64:
        offset = _build_scalar_table(builder, DoubleValueStart, DoubleValueAddValue, DoubleValueEnd, float(value))
    elif dtype == Value.str:
        # The string must be created before the table that references it is started
        s_offset = builder.CreateString(str(value))
        offset = _build_scalar_table(builder, StringValueStart, StringValueAddData, StringValueEnd, s_offset)
    elif dtype == Value.bin:
        b_offset = builder.CreateByteVector(bytes(value))
        offset = _build_scalar_table(builder, BinaryValueStart, BinaryValueAddData, BinaryValueEnd, b_offset)
    elif dtype == Value.str_list:
        s_offsets = [builder.CreateString(str(s)) for s in value]
        StringListStartDataVector(builder, len(s_offsets))
        # Entries are prepended in reverse so they keep their order within the vector
        for o in reversed(s_offsets):
            builder.PrependUOffsetTRelative(o)
        v = builder.EndVector_Patched(len(s_offsets))
        offset = _build_scalar_table(builder, StringListStart, StringListAddData, StringListEnd, v)
    elif dtype == Value.int32_list:
        # Materialize first: a set is an accepted container but cannot be passed to reversed()
        entries = list(value)
        Int32ListStartDataVector(builder, len(entries))
        for o in reversed(entries):
            builder.PrependInt32(int(o))
        v = builder.EndVector_Patched(len(entries))
        offset = _build_scalar_table(builder, Int32ListStart, Int32ListAddData, Int32ListEnd, v)
    elif dtype == Value.float_list:
        entries = list(value)
        FloatListStartDataVector(builder, len(entries))
        for o in reversed(entries):
            builder.PrependFloat32(float(o))
        v = builder.EndVector_Patched(len(entries))
        offset = _build_scalar_table(builder, FloatListStart, FloatListAddData, FloatListEnd, v)
    else:
        raise ValueError(f'Unknown data type: {dtype}')

    return dtype, offset


def _parse_value(fb_type, fb_value):
    """Convert the flatbuffer value into a Python value

    Args:
        fb_type: Value union tag of the entry
        fb_value: Generic flatbuffer table of the entry (provides ``Bytes`` and ``Pos``)

    Returns:
        The corresponding Python scalar, str, bytes or list

    Raises:
        RuntimeError: If the union tag is not a known data type
    """
    def _load(table_class):
        # Re-interpret the generic union table as the concrete table type
        table = table_class()
        table.Init(fb_value.Bytes, fb_value.Pos)
        return table

    if fb_type == Value.boolean:
        return _load(BoolValue).Value()
    if fb_type == Value.i8:
        return _load(Int8Value).Value()
    if fb_type == Value.i16:
        return _load(Int16Value).Value()
    if fb_type == Value.i32:
        return _load(Int32Value).Value()
    if fb_type == Value.i64:
        return _load(Int64Value).Value()
    if fb_type == Value.u8:
        return _load(Uint8Value).Value()
    if fb_type == Value.u16:
        return _load(Uint16Value).Value()
    if fb_type == Value.u32:
        return _load(Uint32Value).Value()
    if fb_type == Value.u64:
        return _load(Uint64Value).Value()
    if fb_type == Value.f32:
        return _load(FloatValue).Value()
    if fb_type == Value.f64:
        return _load(DoubleValue).Value()
    if fb_type == Value.str:
        return _load(StringValue).Data().decode('utf-8')
    if fb_type == Value.bin:
        return _load(BinaryValue).DataAsNumpy().tobytes()
    if fb_type == Value.str_list:
        fb_vector = _load(StringList)
        return [fb_vector.Data(i).decode('utf-8') for i in range(fb_vector.DataLength())]
    if fb_type == Value.int32_list:
        fb_vector = _load(Int32List)
        return [fb_vector.Data(i) for i in range(fb_vector.DataLength())]
    if fb_type == Value.float_list:
        fb_vector = _load(FloatList)
        return [fb_vector.Data(i) for i in range(fb_vector.DataLength())]

    raise RuntimeError(f'Unknown flatbuffer data type: {fb_type}')


def _EndVector_Patched(self, l):
    """This works around a discrepancy between newer and older
    version of the flatbuffers.Builder.EndVector API
    """
    try:
        return self.EndVector(l)
    except TypeError:
        # This EndVector() does not accept the element count argument.
        # Only TypeError is handled so that other errors and interrupts are not swallowed.
        return self.EndVector()


# Make the version independent helper available on every Builder instance
flatbuffers.Builder.EndVector_Patched = _EndVector_Patched