class TfliteTensor(_tflite_schema_fb.TensorT):
    """Wrapper for TFLite flatbuffer tensor"""

    def __init__(self, index: int = -1, model: TfliteModel = None, fb_tensor: _tflite_schema_fb.TensorT = None):
        """Create a tensor wrapper.

        Args:
            index: Index of this tensor in the subgraph's ``tensors`` list
            model: Owning :py:class:`TfliteModel` (or None for a detached tensor)
            fb_tensor: Flatbuffer tensor to copy fields from (or None for an empty tensor)
        """
        _tflite_schema_fb.TensorT.__init__(self)
        if fb_tensor is not None:
            # Shallow-copy every flatbuffer field onto this wrapper instance
            for x in vars(fb_tensor):
                setattr(self, x, getattr(fb_tensor, x))
        else:
            self.shape = None
            self.quantization = None

        self._model = model
        self._index = int(index)
        # Flatbuffer names are bytes; normalize to a UTF-8 str ('' when absent)
        self.name = '' if not self.name else self.name.decode("utf-8")

        if model is not None and fb_tensor is not None:
            buffer = model.flatbuffer_model.buffers[fb_tensor.buffer]
            if buffer.data is not None:
                data_bytes = buffer.data.tobytes() if isinstance(buffer.data, np.ndarray) else bytes(buffer.data)
                if hasattr(_tflite_schema_fb.TensorType, 'INT4') and self.type == _tflite_schema_fb.TensorType.INT4:
                    raw_data_array = self._unpack_int4(data_bytes)
                else:
                    raw_data_array = np.frombuffer(data_bytes, dtype=self.dtype)
                if raw_data_array.size == self.shape.flat_size and len(self.shape) > 1:
                    self._data = raw_data_array.reshape(self.shape)
                else:
                    self._data = raw_data_array

        if not hasattr(self, '_data'):
            # No backing buffer: allocate a zeroed array of the tensor's shape
            s = self.shape if len(self.shape) > 1 else (0,)
            self._data = np.zeros(s, dtype=self.dtype)

    def _unpack_int4(self, data_bytes: bytes) -> np.ndarray:
        """Unpack densely-packed int4 buffer data into an int8 array.

        NumPy does not support int4 so we have to expand to int8.
        See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/portable_tensor_utils.cc
        UnpackDenseInt4IntoInt8()
        """
        n_elements = self.shape.flat_size
        raw_data_array = np.empty((n_elements,), dtype=np.int8)
        sign_bit_mask = 1 << (4 - 1)

        def sign_extend_4bits(value):
            return (value & (sign_bit_mask - 1)) - (value & sign_bit_mask)

        for i in range(n_elements // 2):
            v = data_bytes[i]
            raw_data_array[i * 2 + 0] = sign_extend_4bits(v & 0x0F)
            raw_data_array[i * 2 + 1] = sign_extend_4bits((v & 0xF0) >> 4)

        # If the element count is odd, the final value lives in the LOW nibble
        # of the last byte (matches UnpackDenseInt4IntoInt8() and the packing
        # performed by the `data` setter below).
        # BUGFIX: previously read the high nibble, corrupting the last element.
        if n_elements % 2 != 0:
            v = data_bytes[n_elements // 2]
            raw_data_array[-1] = sign_extend_4bits(v & 0x0F)

        return raw_data_array

    @property
    def index(self) -> int:
        """Index of tensor in .tflite subgraph.tensors list"""
        return self._index

    @property
    def dtype(self) -> np.dtype:
        """Tensor data type"""
        return tflite_to_numpy_dtype(self.type)

    @dtype.setter
    def dtype(self, v):
        self.type = numpy_to_tflite_type(v)

    @property
    def dtype_str(self) -> str:
        """Tensor data type as a string"""
        return self.dtype.__name__.replace('numpy.', '')

    @property
    def shape(self) -> TfliteShape:
        """The tensor shape"""
        return TfliteShape(self._shape)

    @shape.setter
    def shape(self, v):
        self._shape = (0,) if v is None else v

    @property
    def quantization(self) -> TfliteQuantization:
        """Data quantization information"""
        return self._quantization

    @quantization.setter
    def quantization(self, v: TfliteQuantization):
        self._quantization = TfliteQuantization(v)

    @property
    def is_variable(self) -> bool:
        """True if this tensor is populated at runtime and its state persists between inferences"""
        return self.isVariable

    @property
    def size_bytes(self) -> int:
        """The number of bytes required to hold the data for this tensor"""
        if self._data is None:
            return 0
        return self.data.nbytes

    @property
    def data(self) -> np.ndarray:
        """Tensor data"""
        return self._data

    @data.setter
    def data(self, v: Union[np.ndarray, bytes]):
        """Set the tensor data and synchronize the owning model's flatbuffer.

        Args:
            v: Either an ndarray matching this tensor's dtype/flat size,
               or raw bytes (stored as a uint8 array)

        Raises:
            ValueError: if an ndarray's dtype or element count does not match this tensor
        """
        if isinstance(v, np.ndarray):
            if v.dtype != self.dtype:
                raise ValueError(f'Data type must be {self.dtype}')
            if v.size != self.shape.flat_size:
                raise ValueError(f'Number of elements in data must be {self.shape.flat_size}')
            if len(v.shape) == 1:
                v = v.reshape(self.shape)
            self._data = v
        else:
            self._data = np.frombuffer(v, dtype=np.uint8)

        # Propagate the new data into the model's flatbuffer (when attached).
        # BUGFIX: guard on "is not None" — `_model` is always assigned by
        # __init__, so hasattr() alone let a None model fall through and crash.
        if getattr(self, '_model', None) is not None:
            buffer = _tflite_schema_fb.BufferT()
            data_bytes = self._data.tobytes()
            if hasattr(_tflite_schema_fb.TensorType, 'INT4') and self.type == _tflite_schema_fb.TensorType.INT4 and isinstance(v, np.ndarray):
                data_bytes = self._pack_int4(data_bytes)
            buffer.data = np.frombuffer(data_bytes, dtype=np.uint8)
            self._model.flatbuffer_model.buffers[self.buffer] = buffer
            self._model.regenerate_flatbuffer()

    @staticmethod
    def _pack_int4(data_bytes: bytes) -> bytearray:
        """Pack int8 bytes (values assumed in the int4 range -8..7) two-per-byte.

        NumPy does not support int4 so we have to pack the two int8 values into 1 byte.
        See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/portable_tensor_utils.cc
        """
        data_bytes_len = len(data_bytes)
        packed_data = bytearray()
        # BUGFIX: the old `assert value.bit_length() <= 4` rejected every
        # negative int4 value (e.g. -1 -> byte 0xFF -> bit_length 8).
        # Masking the two's-complement low nibble is the correct pack.
        for i in range(data_bytes_len // 2):
            lower = data_bytes[i * 2 + 0] & 0x0F
            upper = data_bytes[i * 2 + 1] & 0x0F
            packed_data.append(lower | (upper << 4))
        # Odd element count: final value goes in the low nibble of a new byte
        if data_bytes_len % 2 != 0:
            packed_data.append(data_bytes[-1] & 0x0F)
        return packed_data

    @property
    def model(self) -> TfliteModel:
        """Reference to associated TfliteModel"""
        return self._model

    def shape_dtype_str(self, include_batch=False) -> str:
        """Return the shape and data-type of this tensor as a string: <dim0>x<dim1>x... (<dtype>)"""
        shape = self.shape
        if not include_batch and len(shape) > 1:
            # Drop the leading batch dimension by default
            shape = shape[1:]
        return f'{"x".join(f"{d}" for d in shape)} ({self.dtype_str})'
def __str__(self):
    """Render the shape as an 'x'-separated dimension string, e.g. '1x28x28'."""
    dims = (f'{dim}' for dim in self)
    return 'x'.join(dims)

@property
def flat_size(self) -> int:
    """Total number of elements or flatten size"""
    total = 1
    for dim in self:
        total *= dim
    return total
class TfliteQuantization(_tflite_schema_fb.QuantizationParametersT):
    """Wrapper for tensor quantization

    Refer to `Quantization Specification <https://www.tensorflow.org/lite/performance/quantization_spec>`_
    for more details.
    """

    # The getters/setters below bypass attribute lookup on the base class by
    # storing normalized values directly in the instance __dict__, using the
    # flatbuffer field names ('scale', 'zeroPoint', 'quantizedDimension') as keys.

    @property
    def scale(self) -> List[float]:
        """Quantization scalers as list of float values"""
        return self.__dict__.get('scale')

    @scale.setter
    def scale(self, v: List[float]):
        # Normalize None -> [] and coerce every entry to float
        self.__dict__['scale'] = [] if v is None else [float(entry) for entry in v]

    @property
    def zeropoint(self) -> List[int]:
        """Quantization zero points as list of integers"""
        return self.__dict__.get('zeroPoint')

    @zeropoint.setter
    def zeropoint(self, v):
        # Normalize None -> [] and coerce every entry to int
        self.__dict__['zeroPoint'] = [] if v is None else [int(entry) for entry in v]

    @property
    def quantization_dimension(self) -> int:
        """Quantization dimension"""
        return self.__dict__.get('quantizedDimension', None)

    @quantization_dimension.setter
    def quantization_dimension(self, v: int):
        self.__dict__['quantizedDimension'] = v

    @property
    def n_channels(self) -> int:
        """Number of channels. This is the number of elements in :py:attr:`~scale` and :py:attr:`~zeropoint`"""
        return len(self.scale)
def tflite_to_numpy_dtype(tflite_type: _tflite_schema_fb.TensorType) -> np.dtype:
    """Convert a tflite schema dtype to numpy dtype

    Args:
        tflite_type: A ``TensorType`` enum value from the .tflite schema

    Raises:
        ValueError: if the tflite type has no numpy equivalent
    """
    if tflite_type == _tflite_schema_fb.TensorType.FLOAT32:
        return np.float32
    elif tflite_type == _tflite_schema_fb.TensorType.FLOAT16:
        return np.float16
    elif tflite_type == _tflite_schema_fb.TensorType.INT32:
        return np.int32
    elif tflite_type == _tflite_schema_fb.TensorType.UINT8:
        return np.uint8
    elif tflite_type == _tflite_schema_fb.TensorType.INT64:
        return np.int64
    elif tflite_type == _tflite_schema_fb.TensorType.INT16:
        return np.int16
    elif tflite_type == _tflite_schema_fb.TensorType.INT8:
        return np.int8
    elif hasattr(_tflite_schema_fb.TensorType, 'INT4') and tflite_type == _tflite_schema_fb.TensorType.INT4:
        # Numpy does not support 4-bit, so we have to use int8
        return np.int8
    elif tflite_type == _tflite_schema_fb.TensorType.BOOL:
        # BUGFIX: np.bool8 was deprecated in NumPy 1.20 and removed in 1.24;
        # np.bool_ is the long-standing canonical scalar type (same type object).
        return np.bool_
    else:
        raise ValueError(f'Unsupported .tflite tensor data type: {tflite_type}')


def numpy_to_tflite_type(dtype: np.dtype) -> _tflite_schema_fb.TensorType:
    """Convert numpy dtype to tflite schema dtype

    Args:
        dtype: A numpy scalar type (e.g. ``np.float32``)

    Raises:
        ValueError: if the numpy dtype has no .tflite equivalent
    """
    if dtype == np.float32:
        return _tflite_schema_fb.TensorType.FLOAT32
    elif dtype == np.float16:
        return _tflite_schema_fb.TensorType.FLOAT16
    elif dtype == np.int32:
        return _tflite_schema_fb.TensorType.INT32
    elif dtype == np.uint8:
        return _tflite_schema_fb.TensorType.UINT8
    elif dtype == np.int64:
        return _tflite_schema_fb.TensorType.INT64
    elif dtype == np.int16:
        return _tflite_schema_fb.TensorType.INT16
    elif dtype == np.int8:
        return _tflite_schema_fb.TensorType.INT8
    elif dtype == np.bool_:
        # BUGFIX: np.bool8 was removed in NumPy 1.24; np.bool_ is the supported alias
        return _tflite_schema_fb.TensorType.BOOL
    else:
        raise ValueError(f'Unsupported .tflite tensor data type: {dtype}')
Important: We use cookies only for functional and traffic analytics.
We DO NOT use cookies for any marketing purposes. By using our site you acknowledge you have read and understood our Cookie Policy.