"""TENet - Temporal efficient neural network***********************************************This is a keyword spotting architecture with temporal and depthwise convolutions. Li, Wei, Qin: Small-Footprint Keyword Spotting with Multi-Scale Temporal Convolution, `https://arxiv.org/pdf/2010.09960.pdf`_"""fromtypingimportUnion,List,Tupleimportnumpyasnpimporttensorflowastf
def TENet(
    input_shape: Union[List[int], Tuple[int, ...]],
    classes: int,
    channels: int = 32,
    blocks: int = 3,
    block_depth: int = 4,
    scales: List[int] = None,
    channel_increase: float = 0.0,
    include_head: bool = True,
    return_model: bool = True,
    input_layer: tf.keras.layers.Input = None,
    dropout: float = 0.1,
    *args,
    **kwargs,
) -> Union[tf.keras.Model, tf.keras.layers.Layer]:
    """Temporal efficient neural network (TENet)

    A network for processing spectrogram data using temporal and depthwise
    convolutions. The network treats the [T, F] spectrogram as a timeseries
    shaped [T, 1, F].

    .. note::
        When building the model, prefer a concrete input shape, i.e. explicitly
        reshape the samples to [T, 1, F] in the preprocessing pipeline. With a
        concrete T the classifier head averages over time with a fixed-size
        ``AveragePooling2D``; if T is unknown (``None``) the head falls back to
        ``GlobalAveragePooling2D``, which computes the same mean over time.

    .. seealso::
        * https://arxiv.org/pdf/2010.09960.pdf

    Args:
        input_shape: Shape of a single sample without the batch dimension; must
            be compatible with (T, 1, C). Only used when `input_layer` is None.
        classes: Number of classes the network is built to categorize
        channels: Base number of channels in the network
        blocks: Number of (StridedIBB -> IBB -> ...) blocks in the networks
        block_depth: Number of IBBs inside each (StridedIBB -> IBB -> ...) block,
            including the strided IBB
        scales: The multitemporal convolution filter widths. Should be odd
            numbers >= 3. Defaults to [9].
        channel_increase: If nonzero, the network increases the channel size each
            time there is a strided IBB block. The increase (each time) is given
            by `channels * channel_increase`.
        include_head: If true, add a classifier head to the model
        return_model: If true, return a Keras model
        input_layer: Use the given layer as the input to the model. If None,
            then create a layer using the given input shape.
        dropout: The dropout to use when include_head=True
        *args, **kwargs: Accepted for call compatibility; they are ignored.

    Returns:
        A Keras model named "TENet" when `return_model` is true, otherwise the
        output tensor of the last layer.

    Raises:
        ValueError: If `input_layer` is None and `input_shape` is not a
            tuple/list/TensorShape compatible with (T, 1, C).
    """
    # Avoid a shared mutable default; [9] is the single-scale default.
    if scales is None:
        scales = [9]

    count_layers = blocks * block_depth

    if input_layer is None:
        if isinstance(input_shape, (tuple, list)):
            input_shape = tf.TensorShape(input_shape)
        if not isinstance(input_shape, tf.TensorShape):
            raise ValueError("Invalid input_shape: Expected only one input")
        if not input_shape.is_compatible_with((None, 1, None)):
            raise ValueError(
                f"Invalid input_shape: Expected (T, 1, C) but received {input_shape}"
            )
        input_layer = tf.keras.layers.Input(shape=input_shape)

    x = input_layer
    x = tf.keras.layers.Conv2D(
        channels,
        (3, 1),
        padding="same",
        use_bias=True,
    )(x)
    x = tf.keras.layers.BatchNormalization()(x)

    for layer in range(count_layers):
        block_index = layer // block_depth
        x = InvertedBottleneckBlock(
            x,
            # Block b (0-based) uses channels * (1 + channel_increase * (b + 1)),
            # so the width only changes at the strided first IBB of each block.
            channels=int(channels * (1 + channel_increase * (1 + block_index))),
            # The first IBB of every block downsamples the time axis.
            stride=2 if (layer % block_depth) == 0 else 1,
            scales=scales,
        )

    if include_head:
        time_steps = x.shape[1]
        if time_steps is None:
            # Dynamic time axis: a fixed pool size cannot be built. The width
            # axis is 1, so a global average equals the mean over time, giving
            # the same (N, C) result as AveragePooling2D + Flatten.
            x = tf.keras.layers.GlobalAveragePooling2D()(x)
        else:
            x = tf.keras.layers.AveragePooling2D(pool_size=(time_steps, 1))(x)
            x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dropout(dropout)(x)
        x = tf.keras.layers.Dense(classes, activation=tf.keras.activations.softmax)(x)

    if return_model:
        return tf.keras.models.Model(input_layer, x, name="TENet")
    return x
def _default_scales(mtconv: bool) -> List[int]:
    """Return the default temporal scales for the TENet presets."""
    # Multi-scale widths when `mtconv` is set, otherwise a single width-9 filter.
    return [9, 7, 5, 3] if mtconv else [9]


def TENet12(input_shape, classes: int, mtconv: bool = False, **kwargs) -> tf.keras.Model:
    """TENet preset: 3 blocks x 4 IBBs (12 IBBs), 32 base channels.

    Any of `channels`, `blocks`, `block_depth`, `scales` given in kwargs
    override the preset; remaining kwargs are forwarded to `TENet`.
    """
    return TENet(
        input_shape=input_shape,
        classes=classes,
        channels=kwargs.pop("channels", 32),
        blocks=kwargs.pop("blocks", 3),
        block_depth=kwargs.pop("block_depth", 4),
        scales=kwargs.pop("scales", _default_scales(mtconv)),
        **kwargs,
    )


def TENet6(input_shape, classes: int, mtconv: bool = False, **kwargs) -> tf.keras.Model:
    """TENet preset: 3 blocks x 2 IBBs (6 IBBs), 32 base channels.

    Any of `channels`, `blocks`, `block_depth`, `scales` given in kwargs
    override the preset; remaining kwargs are forwarded to `TENet`.
    """
    return TENet(
        input_shape=input_shape,
        classes=classes,
        channels=kwargs.pop("channels", 32),
        blocks=kwargs.pop("blocks", 3),
        block_depth=kwargs.pop("block_depth", 2),
        scales=kwargs.pop("scales", _default_scales(mtconv)),
        **kwargs,
    )


def TENet12Narrow(input_shape, classes: int, mtconv: bool = False, **kwargs) -> tf.keras.Model:
    """TENet preset: 3 blocks x 4 IBBs (12 IBBs), 16 base channels.

    Any of `channels`, `blocks`, `block_depth`, `scales` given in kwargs
    override the preset; remaining kwargs are forwarded to `TENet`.
    """
    return TENet(
        input_shape=input_shape,
        classes=classes,
        channels=kwargs.pop("channels", 16),
        blocks=kwargs.pop("blocks", 3),
        block_depth=kwargs.pop("block_depth", 4),
        scales=kwargs.pop("scales", _default_scales(mtconv)),
        **kwargs,
    )


def TENet6Narrow(input_shape, classes: int, mtconv: bool = False, **kwargs) -> tf.keras.Model:
    """TENet preset: 3 blocks x 2 IBBs (6 IBBs), 16 base channels.

    Any of `channels`, `blocks`, `block_depth`, `scales` given in kwargs
    override the preset; remaining kwargs are forwarded to `TENet`.
    """
    return TENet(
        input_shape=input_shape,
        classes=classes,
        channels=kwargs.pop("channels", 16),
        blocks=kwargs.pop("blocks", 3),
        block_depth=kwargs.pop("block_depth", 2),
        scales=kwargs.pop("scales", _default_scales(mtconv)),
        **kwargs,
    )


def HFTENet12(input_shape, classes: int, mtconv: bool = False, **kwargs) -> tf.keras.Model:
    """Custom TENet variant with channels that increase as the time axis shrinks.

    Same layout as `TENet12`, with `channel_increase` defaulting to 0.125. It can
    now be overridden via kwargs; passing it previously raised a TypeError for a
    duplicate keyword.
    """
    return TENet(
        input_shape=input_shape,
        classes=classes,
        channels=kwargs.pop("channels", 32),
        blocks=kwargs.pop("blocks", 3),
        block_depth=kwargs.pop("block_depth", 4),
        scales=kwargs.pop("scales", _default_scales(mtconv)),
        channel_increase=kwargs.pop("channel_increase", 0.125),
        **kwargs,
    )


class MultiScaleTemporalConvolution(tf.keras.layers.Layer):
    """Convolution which combines results from several filter widths with padding.

    Each scale is a depthwise ``(scale, 1)`` convolution with "same" padding and
    the configured stride; the branch outputs are summed. Convolution is a
    linearly separable operation, so these different filters can be superimposed
    during inference: each smaller kernel is zero-padded symmetrically to the
    largest width and added. This makes it possible to flatten the different
    filters into one.

    Call `.fuse` after training to flatten the filters into one.
    """

    # Process-wide counter used to make generated sub-layer names unique. It is
    # always advanced on the class itself (see `_next_layer_id`); `self.x += 1`
    # would instead create a per-instance copy and leave the shared value at 0.
    _layer_counter: int = 0

    def __init__(
        self,
        stride: int = 1,
        scales: Union[int, List[int], Tuple[int, ...]] = None,
        **kwargs,
    ):
        """
        Args:
            stride: Stride of every depthwise branch.
            scales: Odd filter widths >= 3, either a single int or a list/tuple.
                Defaults to [3, 5, 7, 9] when None.
            **kwargs: Forwarded to `tf.keras.layers.Layer`.

        Raises:
            TypeError: If `scales` is not an int, list, or tuple.
            ValueError: If `scales` is empty, or contains values < 3 or even values.
        """
        # Only None selects the default; an explicitly empty sequence must reach
        # the emptiness check below instead of being silently replaced.
        if scales is None:
            scales = [3, 5, 7, 9]
        if isinstance(scales, int):
            scales = [scales]
        if not isinstance(scales, (list, tuple)):
            raise TypeError(
                f"Expected scales to be an int or a tuple/list of ints, but received {type(scales)}"
            )
        if len(scales) < 1:
            raise ValueError("Expected atleast one temporal scale, received 0")
        if any(scale < 3 for scale in scales):
            raise ValueError(f"Expected scales >= 3, received {scales}")
        if any((scale % 2) != 1 for scale in scales):
            raise ValueError(f"Expected odd scales, received {scales}")

        self.scales = list(scales)
        self.stride = stride
        self.temporal_convolutions: List[tf.keras.layers.DepthwiseConv2D] = []
        # Remembered in `build` so `fuse`/`fused` can build the replacement conv.
        self._input_shape = None
        super().__init__(**kwargs)

    @staticmethod
    def _next_layer_id() -> int:
        """Return the current shared counter value and advance it on the class."""
        layer_id = MultiScaleTemporalConvolution._layer_counter
        MultiScaleTemporalConvolution._layer_counter = layer_id + 1
        return layer_id

    def get_config(self):
        """Serialize `stride` and `scales` on top of the base layer config."""
        config = super(MultiScaleTemporalConvolution, self).get_config()
        config.update({"stride": self.stride})
        config.update({"scales": self.scales})
        return config

    def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
        """Create and build one depthwise (scale, 1) convolution per scale."""
        self._input_shape = input_shape
        layer_id = self._next_layer_id()
        self.temporal_convolutions = [
            tf.keras.layers.DepthwiseConv2D(
                kernel_size=(scale, 1),
                strides=self.stride,
                padding="same",
                use_bias=False,
                name=f"mtconv{scale}-{i}-{layer_id}",
                input_shape=input_shape[1:],
            )
            for i, scale in enumerate(self.scales)
        ]
        for branch in self.temporal_convolutions:
            branch.build(input_shape)
        return super().build(input_shape)

    def fuse(self):
        """Fuse convolutions in-place.

        Replaces all branches with the single convolution returned by `fused`
        and sets `scales` to the largest width, so `get_config` describes the
        fused layer.
        """
        fused_convolution = self.fused()
        scale = max(self.scales)
        self.temporal_convolutions = [fused_convolution]
        self.scales = [scale]
        super().build(self._input_shape)

    def fused(self) -> tf.keras.layers.DepthwiseConv2D:
        """Returns a new depthwise conv, created by fusing the filters of each
        temporal convolution in this MTConv block.

        The kernel of the widest branch is copied. Every other branch's kernel is
        zero-padded symmetrically along the time axis to the same width and added.

        Raises:
            RuntimeError: If the layer has not been built yet.
        """
        if self._input_shape is None:
            raise RuntimeError(
                "MultiScaleTemporalConvolution must be built before it can be fused"
            )

        # Plain Python index of the widest scale; tf.argmax would return a Tensor
        # that only works as a list index in eager mode.
        argmax_scale = self.scales.index(max(self.scales))
        max_scale: int = self.scales[argmax_scale]

        fused_convolution = tf.keras.layers.DepthwiseConv2D(
            kernel_size=(max_scale, 1),
            strides=self.stride,
            padding="same",
            use_bias=False,
            name=f"mtconv_superimposed-{self._next_layer_id()}",
        )
        fused_convolution.build(self._input_shape)

        # use_bias=False on every branch, so the kernel is the only weight.
        fused_kernel = np.copy(self.temporal_convolutions[argmax_scale].get_weights()[0])
        for i, (scale, branch) in enumerate(zip(self.scales, self.temporal_convolutions)):
            if i == argmax_scale:
                continue
            weights: np.ndarray = branch.get_weights()[0]
            assert weights.shape[0] == scale, f"Unexpected weight shape {weights.shape}"
            # Odd widths => (max_scale - scale) is even, so padding is symmetric
            # and the kernel centres stay aligned.
            pad = (max_scale - scale) // 2
            fused_kernel = fused_kernel + np.pad(
                weights,
                pad_width=((pad, pad), (0, 0), (0, 0), (0, 0)),
                mode="constant",
            )
        fused_convolution.set_weights([fused_kernel])
        return fused_convolution

    def compute_output_shape(self, input_shape):
        """Output shape of the first branch; every branch shares stride and
        "same" padding, so their output shapes agree."""
        return self.temporal_convolutions[0].compute_output_shape(input_shape)

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        """Sum of all branch outputs (a single branch is returned directly)."""
        branches = [conv(inputs) for conv in self.temporal_convolutions]
        if len(branches) == 1:
            return branches[0]  # type: ignore
        output = tf.stack(branches, axis=0)
        return tf.reduce_sum(output, axis=0)


def InvertedBottleneckBlock(
    x: tf.keras.layers.Layer,
    channels: int,
    stride: int,
    expansion_ratio: Union[float, int] = 3,
    scales: List[int] = None,
) -> tf.keras.layers.Layer:
    """Inverted bottleneck with depthwise separable temporal convolution and a
    residual connection.

    Structure: 1x1 expand conv -> BN -> ReLU -> MultiScaleTemporalConvolution
    (carries the stride) -> BN -> ReLU -> 1x1 contract conv -> BN, added to the
    residual path and passed through a final ReLU. For stride 1 the residual is
    the input itself; otherwise it is a strided 1x1 conv -> BN -> ReLU.

    Args:
        x: Input tensor shaped (N, T, 1, C).
        channels: Output channel count of the block.
        stride: Temporal stride; a channel change is only allowed when != 1.
        expansion_ratio: Hidden width is `int(channels * expansion_ratio)`.
        scales: Filter widths for the temporal convolution; defaults to [9].

    Returns:
        The block's output tensor.

    Raises:
        ValueError: If `x` is not rank 4 with width 1, or if `stride == 1` and the
            input channel count differs from `channels`.
    """
    input_shape = x.shape
    # Both conditions must hold; the previous `not a and b` only rejected
    # inputs that had the wrong rank *and* width 1.
    if len(input_shape) != 4 or input_shape[-2] != 1:
        raise ValueError(
            f"Invalid input_shape: Expected (N, T, 1, C) but received {input_shape}"
        )
    if stride == 1 and input_shape[-1] != channels:
        raise ValueError(
            "Channel change is only supported for strided layers. "
            f"Expected input_shape channels ({input_shape[-1]}) "
            f"to match the bottleneck output channels ({channels})"
        )

    # Process-wide counter kept in module globals (key unchanged for callers
    # that may reset it); gives every block's layers unique names.
    layer_id = globals().get("InvertedBottleneckBlock_layer_id", 0)
    globals()["InvertedBottleneckBlock_layer_id"] = layer_id + 1

    expansion_channels = int(channels * expansion_ratio)
    scales = scales or [9]

    layer_input = x

    x = tf.keras.layers.Conv2D(
        filters=expansion_channels,
        kernel_size=(1, 1),
        strides=1,
        use_bias=False,
        name=f"pointwise_expand_conv-{layer_id}",
    )(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    x = MultiScaleTemporalConvolution(
        stride=stride,
        scales=scales,
        name=f"mtconv-{layer_id}",
    )(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    x = tf.keras.layers.Conv2D(
        filters=channels,
        kernel_size=(1, 1),
        strides=1,
        use_bias=False,
        name=f"pointwise_contract_conv-{layer_id}",
    )(x)
    x = tf.keras.layers.BatchNormalization()(x)
    bottleneck_temporal_depth_separable_convolution = x

    if stride == 1:
        x = tf.keras.layers.Add()(
            [bottleneck_temporal_depth_separable_convolution, layer_input]
        )
    else:
        # Strided 1x1 projection so the residual matches the downsampled time
        # axis and the (possibly new) channel count.
        residuals = tf.keras.layers.Conv2D(
            filters=channels,
            kernel_size=(1, 1),
            strides=stride,
            padding="same",
            use_bias=False,
            name=f"strided_residual-{layer_id}",
        )(layer_input)
        residuals = tf.keras.layers.BatchNormalization()(residuals)
        residuals = tf.keras.layers.ReLU()(residuals)
        x = tf.keras.layers.Add()(
            [bottleneck_temporal_depth_separable_convolution, residuals]
        )

    x = tf.keras.layers.ReLU()(x)
    return x
Important: We use cookies only for functional and traffic analytics.
We DO NOT use cookies for any marketing purposes. By using our site you acknowledge you have read and understood our Cookie Policy.