Source code for mltk.core.tflite_model.tflite_model
from__future__importannotationsimportosimportwarningsfromtypingimportList,Dict,Union,IteratorfromprettytableimportPrettyTableimportnumpyasnp# Disable the "DeprecationWarning" found in the flatbuffer packagewarnings.filterwarnings("ignore",category=DeprecationWarning)from.importtflite_schemaas_tflite_schema_fbfrom.tflite_schemaimportBuiltinOperatorasTfliteOpCode# pylint: disable=unused-importfrom.tflite_schemaimportflatbuffersfrom.tflite_tensorimportTfliteTensorfrom.tflite_layerimportTfliteLayerTFLITE_FILE_IDENTIFIER=b"TFL3"
[docs]classTfliteModel:"""Class to access a .tflite model flatbuffer's layers and tensors Refer to `schema_v3.fbs <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema_v3.fbs>`_ for more details on the .tflite flatbuffer schema **Example Usage** .. highlight:: python .. code-block:: python from mltk.core import TfliteModel # Load you .tflite model file model = TfliteModel.load_flatbuffer_file('some/path/my_model.tflite') # Print a summary of the model print(tflite_model.summary()) # Iterate through each layer of the model for layer in tflite_model.layers: # See TfliteLayer for additional info print(layer) # Update the model's description # This updates the .tflite's "description" field (which will be displayed in GUIs like https://netron.app) tflite_model.description = "My awesome model" print(f'New model description: {tflite_model.description}') # Save a new .tflite with the updated description tflite_model.save('some/path/my_new_model.tflite') # Add some metadata to the .tflite metadata = 'this is metadata'.encode('utf-8') tflite_model.add_metadata('my_metadata', metadata) # Retrieve all the metadata in the .tflite all_metadata = tflite_model.get_all_metadata() for key, data in all_metadata.items(): print(f'{key}: length={len(data)} bytes') # Save a new .tflite with the updated metadata tflite_model.save('some/path/my_new_model.tflite') # You must have Tensorflow instance to perform this step # This will run inference with the given buffer and return # the results. The input_buffer can be: # - a single sample as a numpy array # - a numpy array of 1 or more samples # - A Python generator that returns (batch_x, batch_y) # inference_results = tflite_model.predict(..) """
[docs]@staticmethoddefload_flatbuffer_file(path:str,cwd=None)->TfliteModel:"""Load a .tflite flatbuffer file"""found_path=_existing_path(path,cwd=cwd)iffound_pathisNone:raiseFileNotFoundError(f'.tflite model file not found: {path}')withopen(found_path,'rb')asf:flatbuffer_data=f.read()returnTfliteModel(flatbuffer_data=flatbuffer_data,path=found_path)
@propertydefpath(self)->str:"""Path to .tflite file Returns None if no path was specified. The path is normalized and backslashes are converted to forward slash """returnNoneifself._pathisNoneelseos.path.normpath(self._path).replace('\\','/')@path.setterdefpath(self,v:str):"""Path to .tflite file"""ifvisnotNone:v=v.replace('\\','/')self._path=v@propertydeffilename(self)->str:"""File name of associated .tflite model file Return None if not path is set"""ifself._path:returnos.path.basename(self._path)else:returnNone@propertydefname(self)->str:"""The name of the model which is the :py:attr:`~filename` without the ``.tflite`` extension or "my_model" if no path is set"""filename=self.filenameiffilename:iffilename.endswith('.tflite'):filename=filename[:-len('.tflite')]returnfilenameelse:return'my_model'@propertydefdescription(self)->str:"""Get/set model description .. note:: :py:func:`~save` must be called for changes to persist """return''ifself._modelisNoneornotself._model.descriptionelseself._model.description.decode('utf-8')@description.setterdefdescription(self,desc:str):ifself._modelisNone:raiseRuntimeError('Model not loaded')desc=descor''self._model.description=desc.encode('utf-8')self.regenerate_flatbuffer()@propertydefflatbuffer_data(self)->bytes:"""Flatbuffer binary data"""ifself._flatbuffer_dataisNone:returnNonereturnbytes(self._flatbuffer_data)@propertydefflatbuffer_size(self)->int:"""Size of the model flatbuffer in bytes"""ifself.flatbuffer_dataisNone:return0returnlen(self.flatbuffer_data)def__len__(self)->int:returnself.flatbuffer_size@propertydefflatbuffer_model(self)->_tflite_schema_fb.ModelT:"""Flatbuffer schema Model object"""returnself._model@propertydefflatbuffer_subgraph(self)->_tflite_schema_fb.SubGraphT:"""Flatbuffer schema model subgraph"""ifself._modelisNone:returnNonereturnself._model.subgraphs[self._selected_model_subgraph_index]@propertydefselected_model_subgraph(self)->int:"""The index of the selected model subgraph. Other properties and APIs will return layers/tensors from the selected subgraph """returnself._selected_model_subgraph_index@selected_model_subgraph.setterdefselected_model_subgraph(self,v:int):ifself._modelisNone:return-1ifv<0orv>=self.n_subgraphs:raiseValueError('Invalid model subgraph index')self._selected_model_subgraph_index=v@propertydefn_subgraphs(self)->int:"""Return the number of model subgraphs"""ifself._modelisNone:return0returnlen(self._model.subgraphs)@propertydefn_inputs(self)->int:"""Return the number of model inputs"""ifself.flatbuffer_subgraphisNone:return0returnlen(self.flatbuffer_subgraph.inputs)@propertydefinputs(self)->List[TfliteTensor]:"""List of all input tensors"""ifself.flatbuffer_subgraphisNone:returnNoneretval=[]forindexinself.flatbuffer_subgraph.inputs:retval.append(self.get_tensor(index))returnretval@propertydefn_outputs(self)->int:"""Return the number of model outputs"""ifself.flatbuffer_subgraphisNone:return0returnlen(self.flatbuffer_subgraph.outputs)@propertydefoutputs(self)->List[TfliteTensor]:"""List of all output tensors"""ifself.flatbuffer_subgraphisNone:returnNoneretval=[]forindexinself.flatbuffer_subgraph.outputs:retval.append(self.get_tensor(index))returnretval@propertydeflayers(self)->List[TfliteLayer]:"""List of all model layers for the current subgraph"""ifself._selected_model_subgraph_index==-1:returnNonereturnself._subgraphs[self._selected_model_subgraph_index].layers@propertydeftensors(self)->List[TfliteTensor]:"""List of all model tensors for the current subgraph"""ifself._selected_model_subgraph_index==-1:returnNonereturnself._subgraphs[self._selected_model_subgraph_index].tensors
[docs]defsummary(self)->str:"""Generate a summary of the model"""ifself._flatbuffer_dataisNone:return'Not loaded't=PrettyTable()t.field_names=['Index','OpCode','Input(s)','Output(s)','Config']fori,layerinenumerate(self.layers):inputs='\n'.join([x.shape_dtype_str(include_batch=False)forxinlayer.inputsifxisnotNone])outputs='\n'.join([x.shape_dtype_str(include_batch=False)forxinlayer.outputsifxisnotNone])t.add_row([i,layer.opcode_str,inputs,outputs,f'{layer.options}'])t.align='l'returnt.get_string()
[docs]defget_flatbuffer_subgraph(self,index:int=None)->_tflite_schema_fb.SubGraphT:"""Flatbuffer schema model subgraph at the given index If no index is given, then use the selected_model_subgraph """ifself._modelisNone:raiseRuntimeError('Model not loaded')index=indexorself._selected_model_subgraph_indexreturnself._model.subgraphs[index]
[docs]defget_tensor(self,index:int)->TfliteTensor:"""Return a specific model tensor as a TfliteTensor """ifself._modelisNone:raiseRuntimeError('Model not loaded')subgraph=self._subgraphs[self._selected_model_subgraph_index]ifindex>=len(subgraph.tensors):raiseIndexError(f'Index overflow ({index} >= {len(subgraph.tensors)})')returnsubgraph.tensors[index]
[docs]defget_tensor_data(self,index:int)->np.ndarray:"""Return a specific model tensor as a np.ndarray """tensor=self.get_tensor(index=index)iftensorisNone:returnNonereturntensor.data
[docs]defget_input_tensor(self,index:int=0)->TfliteTensor:"""Return a model input tensor as a TfliteTensor"""ifindex>=self.n_inputs:raiseIndexError(f'Index overflow ({index} >= {self.n_inputs})')tensor_index=self.flatbuffer_subgraph.inputs[index]returnself.get_tensor(tensor_index)
[docs]defget_input_data(self,index:int=0)->np.ndarray:"""Return a model input as a np.ndarray"""ifindex>=self.n_inputs:raiseIndexError(f'Index overflow ({index} >= {self.n_inputs})')tensor_index=self.flatbuffer_subgraph.inputs[index]returnself.get_tensor_data(tensor_index)
[docs]defget_output_tensor(self,index:int=0)->TfliteTensor:"""Return a model output tensor as a TfliteTensor"""ifindex>=self.n_outputs:raiseIndexError(f'Index overflow ({index} >= {self.n_outputs})')tensor_index=self.flatbuffer_subgraph.outputs[index]returnself.get_tensor(tensor_index)
[docs]defget_output_data(self,index:int=0)->np.ndarray:"""Return a model output tensor as a np.ndarray"""ifindex>=self.n_outputs:raiseIndexError(f'Index overflow ({index} >= {self.n_outputs})')tensor_index=self.flatbuffer_subgraph.outputs[index]returnself.get_tensor_data(tensor_index)
[docs]defget_all_metadata(self)->Dict[str,bytes]:"""Return all model metadata as a dictionary"""ifself._modelisNone:raiseRuntimeError('Model not loaded')retval={}formetadatainself._model.metadata:name=metadata.name.decode("utf-8")buffer_index=metadata.bufferretval[name]=self._model.buffers[buffer_index].data.tobytes()returnretval
[docs]defget_metadata(self,tag:str)->bytes:"""Return model metadata with specified tag"""ifself._modelisNone:raiseRuntimeError('Model not loaded')metadata_value=Noneformetadatainself._model.metadata:ifmetadata.name.decode("utf-8")==tag:buffer_index=metadata.buffermetadata_value=self._model.buffers[buffer_index].data.tobytes()breakreturnmetadata_value
[docs]defadd_metadata(self,tag:str,value:bytes):"""Set or add metadata to model .. Note:: :func:`~tflite_model.TfliteModel.save` must be called for changes to persist Args: tag (str): The key to use to lookup the metadata value (bytes): The metadata value as a binary blob to add to the .tflite """ifself._modelisNone:raiseRuntimeError('Model not loaded')ifnottagornotvalue:raiseValueError('Must provide valid tag and value arguments')buffer_field=_tflite_schema_fb.BufferT()buffer_field.data=np.frombuffer(value,dtype=np.uint8)add_buffer=Falseifnotself._model.metadata:self._model.metadata=[]else:# Check if metadata has already been add to the model.formetainself._model.metadata:ifmeta.name.decode("utf-8")==tag:add_buffer=Trueself._model.buffers[meta.buffer]=buffer_fieldifnotadd_buffer:ifnotself._model.buffers:self._model.buffers=[]self._model.buffers.append(buffer_field)# Creates a new metadata field.metadata_field=_tflite_schema_fb.MetadataT()metadata_field.name=tag.encode('utf-8')metadata_field.buffer=len(self._model.buffers)-1self._model.metadata.append(metadata_field)self.regenerate_flatbuffer()
[docs]defremove_metadata(self,tag:str)->bool:"""Remove model metadata with specified tag .. Note:: :func:`~tflite_model.TfliteModel.save` must be called for changes to persist Args: tag (str): The key to use to lookup the metadata Return: True if the metadata was found and removed, False else """ifself._modelisNone:raiseRuntimeError('Model not loaded')ifnotself._model.metadata:returnFalseremoved_metadata=Falseformetainself._model.metadata:ifmeta.name.decode("utf-8")==tag:removed_metadata=Trueself._model.metadata.remove(meta)self._model.buffers.pop(meta.buffer)self.regenerate_flatbuffer()breakreturnremoved_metadata
[docs]defsave(self,output_path:str=None,update_path=False):"""Save flatbuffer data to file If output_path is specified then write to new file, otherwise overwrite existing file """output_path=output_pathorself.pathifnotoutput_path:raiseRuntimeError('No output path specified')# Re-generate the underlying flatbufferself.regenerate_flatbuffer()# Create the model's output directory if necessaryout_dir=os.path.dirname(output_path)ifout_dir:os.makedirs(out_dir,exist_ok=True)withopen(output_path,'wb')asf:f.write(self._flatbuffer_data)ifupdate_path:self._path=output_path
[docs]defregenerate_flatbuffer(self,reload_model=False):"""Re-generate the underlying flatbuffer based on the information cached in the local ModelT instance .. Note:: :func:`~tflite_model.TfliteModel.save` must be called for changes to persist """ifself._modelisNone:raiseRuntimeError('Model not loaded')b=flatbuffers.Builder(0)b.Finish(self._model.Pack(b),TFLITE_FILE_IDENTIFIER)self._flatbuffer_data=b.Output()ifreload_model:self._load_model()
[docs]defpredict(self,x:Union[np.ndarray,Iterator,List[np.ndarray],Dict[int,np.ndarray]],y_dtype=None,**kwargs)->np.ndarray:"""Invoke the TfLite interpreter with the given input sample and return the results If the model has a single input and output, the x data can one of: - A single sample as a numpy array - Iterable list of samples - Sample generator In this case, this API will manage converting the samples to the correct data type and adding the necessary batch dimension. The output value will either be list of model predictions or a single prediction corresponding to the input. If the model has multiple inputs and outputs, then the input data must be one of: - Python list of numpy arrays. One numpy array per model input. The numpy arrays must only contain the values for one sample. The input numpy arrays do NOT need to have the batch dimension. In this case, the output values will also not have the batch dimension. - Dictionary of one or more numpy arrays. The dictionary key should be an integer corresponding to the model input, and the value should be a numpy array. The input numpy arrays do NOT need to have the batch dimension. In this case, the output values will also not have the batch dimension. Args: x: The input samples(s) as a numpy array or data generator. If x is a numpy array then it must have the same shape as the model input or it must be a vector (i.e. batch) of samples having the same shape as the model input. The data type must either be the same as the model input's OR it must be a float32, in which case the input sample will automatically be quantized using the model input's quantizing scaler/zeropoint. If x is a generator, then each iteration must return a tuple: batch_x, batch_y batch_x must be a vector (i.e. batch) of samples having the same shape as the model input batch_y is ignored. y_dtype: The return value's data type. By default, data type is None in which case the model output is directly returned. If y_dtype=np.float32 then the model output is de-quantized to float32 using the model's output quantization scaler/zeropoint (if necessary) Returns: Output of model inference, y. If x was a single sample, then y is a single result. Otherwise y is a vector (i.e. batch) of model results. If y_dtype is given, the y if automatically converted/de-quantized to the given dtype. """ifself._flatbuffer_dataisNone:raiseRuntimeError('Model not loaded')input0=self.get_input_tensor(0)input0_shape=input0.shapeifself.n_inputs>1:ifisinstance(x,(list,tuple)):ifnotall(isinstance(t,np.ndarray)fortinx):raiseValueError(''' For multi-input models, the input data must be a list of numpy arrays ''')x={index:valueforindex,valueinenumerate(x)}elifisinstance(x,dict):ifnotall(isinstance(k,int)andisinstance(v,np.ndarray)fork,vinx.items()):raiseValueError(''' For multi-input models, the input data must be a dictionary of numpy arrays with the keys corresponding to the model input index ''')else:raiseValueError(''' For multi-input models, the input data must be a list of numpy arrays or dictionary of numpy arrays with the keys corresponding to the model input index ''')# Set the input tensorshas_batch_dim=Trueforinput_index,x_iinx.items():ifinput_index==0:# Check if the input sample has the batch dimensioniflen(x_i.shape)==len(input0_shape[1:]):has_batch_dim=Falseself._allocate_tflite_interpreter(batch_size=1,interpreter_kwargs=kwargs.get('interpreter_kwargs',None))else:self._allocate_tflite_interpreter(batch_size=x_i.shape[0],interpreter_kwargs=kwargs.get('interpreter_kwargs',None))# Add the batch_size=1 if the input sample doesn't have a batch dimifnothas_batch_dim:x_i=np.expand_dims(x_i,axis=0)# If the input sample isn't the same as the model input dtype,# then we need to manually convert it first# NOTE: If the model input type is float32 then# quantization is done automatically inside the modelx_i=self.quantize_to_input_dtype(x_i,input_index=input_index)self._interpreter.set_tensor(self.get_input_tensor(input_index).index,x_i)# Execute the modelself._interpreter.invoke()# Get the model resultsy=[]fori,outpinenumerate(self.outputs):y_i=self._interpreter.get_tensor(outp.index)# If the input doesn't have a batch dim# then remove the dim from the outputifnothas_batch_dim:y_i=np.squeeze(y_i,axis=0)ify_dtype==np.float32:# Convert the output data type to float32 if necessaryy_i=self.dequantize_output_to_float32(y_i,output_index=i)y.append(y_i)returny# This expects either# [n_samples, input_shape...]# OR# [input_shape ...]ifisinstance(x,np.ndarray):is_single_sample=Falseiflen(x.shape)==len(input0_shape[1:]):is_single_sample=True# Add the batch dimension if we were only given a single samplex=np.expand_dims(x,axis=0)self._allocate_tflite_interpreter(batch_size=1,interpreter_kwargs=kwargs.get('interpreter_kwargs',None))else:self._allocate_tflite_interpreter(batch_size=x.shape[0],interpreter_kwargs=kwargs.get('interpreter_kwargs',None))# If the input sample isn't the same as the model input dtype,# then we need to manually convert it first# NOTE: If the model input type is float32 then# quantization is done automatically inside the modelx=self.quantize_to_input_dtype(x)# If the last dimension of the model's input shape is 1,# and the input data is missing this dimension# then automatically expand the dimensioniflen(input0_shape)!=len(x.shape)andinput0_shape[-1]==1:x=np.expand_dims(x,axis=-1)# Then set model input tensorself._interpreter.set_tensor(input0.index,x)# Execute the modelself._interpreter.invoke()# Get the model resultsy=self._interpreter.get_tensor(self.get_output_tensor(0).index)# Convert the output data type to float32 if necessary# NOTE: If the model output type is float32 then# de-quantization is done automatically inside the modelify_dtype==np.float32:y=self.dequantize_output_to_float32(y)# Remove the batch dimension if we were only given a single sampleifis_single_sample:y=np.squeeze(y,axis=0)returny# Else if we were given a data generatorelse:n_samples=0batch_results=[]forbatchinx:batch_x=batchifnotisinstance(batch,tuple)elsebatch[0]self._allocate_tflite_interpreter(batch_size=batch_x.shape[0])# If the input sample isn't the same as the model input dtype,# then we need to manually convert it firstbatch_x=self.quantize_to_input_dtype(batch_x)# If the last dimension of the model's input shape is 1,# and the batch data is missing this dimension# then automatically expand the dimensioniflen(input0_shape)!=len(batch_x.shape)andinput0_shape[-1]==1:batch_x=np.expand_dims(batch_x,axis=-1)# The set model input tensorself._interpreter.set_tensor(input0.index,batch_x)# Execute the modelself._interpreter.invoke()# Get the model resultsbatch_y=self._interpreter.get_tensor(self.get_output_tensor(0).index)ify_dtype==np.float32:# Convert the output data type to float32 if necessarybatch_y=self.dequantize_output_to_float32(batch_y)batch_results.append(batch_y)n_samples+=len(batch_y)# If the generator specifies a "max_samples" property# then break out of the loop once the specified number of samples have been processedtry:ifhasattr(x,'max_samples')andx.max_samples>0:ifn_samples>=x.max_samples:breakexcept:passiflen(batch_results)==0:raiseRuntimeError('No batch samples where generated by the data given data generator')batch_size=batch_results[0].shape[0]output_shape=batch_results[0].shape[1:]ifhasattr(x,'max_samples')andx.max_samples>0:n_samples=x.max_samplesy=np.zeros((n_samples,*output_shape),dtype=batch_y.dtype)forbatch_index,batchinenumerate(batch_results):forresult_index,resultinenumerate(batch):index=batch_index*batch_size+result_indexifindex>=n_samples:breaky[index,:]=resultreturny
[docs]defquantize_to_input_dtype(self,x:np.ndarray,input_index=0):"""Quantize the input sample(s) to the model's input dtype (if necessary)"""input_tensor=self.get_input_tensor(input_index)ifx.dtype==input_tensor.dtype:returnxifx.dtype!=np.float32:raiseRuntimeError('The sample input must be float32 or the same dtype as the model input')# Convert from float32 to the model input data typex=(x/input_tensor.quantization.scale[0])+input_tensor.quantization.zeropoint[0]returnx.astype(input_tensor.dtype)
[docs]defdequantize_output_to_float32(self,y:np.ndarray,output_index=0):"""De-quantize the model output to float32 (if necessary)"""ify.dtype==np.float32:returnyoutput_tensor=self.get_output_tensor(output_index)y=y.astype(np.float32)return(y-output_tensor.quantization.zeropoint[0])*output_tensor.quantization.scale[0]
def_allocate_tflite_interpreter(self,batch_size=1,interpreter_kwargs=None):ifself._interpreterisNoneorself._interpreter_batch_size!=batch_size:try:importtensorflowastfexceptModuleNotFoundErrorase:raiseModuleNotFoundError(f'You must first install the "tensorflow" Python package to run inference, err: {e}')# pylint: disable=raise-missing-frominterpreter_kwargs=interpreter_kwargsor{}self._interpreter_batch_size=batch_sizeself._interpreter=tf.lite.Interpreter(model_path=self._path,**interpreter_kwargs)input_indices=[]forinpinself.inputs:input_indices.append(inp.index)new_input_shape=(batch_size,*inp.shape[1:])self._interpreter.resize_tensor_input(inp.index,new_input_shape)foroutpinself.outputs:ifoutp.indexininput_indices:continuenew_output_shape=(batch_size,*outp.shape[1:])self._interpreter.resize_tensor_input(outp.index,new_output_shape)self._interpreter.allocate_tensors()def_load_model(self):try:self._model=_tflite_schema_fb.ModelT.InitFromObj(_tflite_schema_fb.Model.GetRootAsModel(self._flatbuffer_data,0))subgraph_count=len(self._model.subgraphs)exceptExceptionase:raiseRuntimeError(# pylint: disable=raise-missing-from'Failed to load .tflite model flatbuffer.\n''Ensure you have provided a valid .tflite model (i.e. ensure the binary data has not been corrupted)\n'f'Error details: {e}')schema_version=self._model.versionifschema_version!=3:raiseRuntimeError('TF-Lite schema v3 is only supported')ifself._selected_model_subgraph_index==-1orself._selected_model_subgraph_index>=subgraph_count:self._selected_model_subgraph_index=0self._subgraphs=[]forfb_subgraphinself._model.subgraphs:subgraph=_TfliteSubgraph()self._subgraphs.append(subgraph)fori,fb_tensorinenumerate(fb_subgraph.tensors):tensor=TfliteTensor(i,self,fb_tensor)subgraph.tensors.append(tensor)fori,operatorinenumerate(fb_subgraph.operators):layer=TfliteLayer.from_flatbuffer(i,self,operator)subgraph.layers.append(layer)
Important: We use cookies only for functional and traffic analytics.
We DO NOT use cookies for any marketing purposes. By using our site you acknowledge you have read and understood our Cookie Policy.