Mean or max pooling with masking support in Keras

...
print('Build model...')
model = Sequential()
model.add(Embedding(max_features, 128))
model.add(LSTM(size, return_sequences=True, dropout_W=0.2, dropout_U=0.2))
model.add(GlobalAveragePooling1D())
model.add(Dense(1))
model.add(Activation('sigmoid'))
...

I need to take the mean or max of the vectors over all time steps of a sample after the LSTM layer, and then feed that mean or max vector to the Dense layer in Keras.

I think TimeDistributedMerge was able to do this, but it was deprecated. With return_sequences=True I can obtain the vectors for all time steps of a sample after the LSTM layer. However, GlobalAveragePooling1D() does not support masking: it averages over all time steps, whereas I need only the non-masked ones.

I saw posts recommending the Lambda layer, but these also do not take masking into account. Any help would be appreciated.
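Stated precisely, with h_t the LSTM output at time step t, T the padded sequence length and m_t in {0, 1} the mask (this notation is introduced here for illustration, it is not Keras API), the two poolings being asked for are a mean divided by the number of unmasked steps, and a per-feature max taken only over unmasked steps:

```latex
\bar{h} = \frac{\sum_{t=1}^{T} m_t \, h_t}{\sum_{t=1}^{T} m_t},
\qquad
\hat{h}_j = \max_{t \,:\, m_t = 1} h_{t,j}
```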

Means answered 15/9, 2016 at 12:16

Jacoxu's answer is right, but if you are using the TensorFlow backend for Keras, TensorFlow tensors don't support the Theano-style dimshuffle method. Try this instead:

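import tensorflow as tf
from keras import backend as K

def call(self, x, mask=None):
    if mask is None:
        # no mask: plain mean over the time axis
        return K.mean(x, axis=1)
    # mask (batch, time)
    mask = K.cast(mask, K.floatx())
    # mask (batch, x_dim, time); K.int_shape gives a plain int for the feature size
    mask = K.repeat(mask, K.int_shape(x)[-1])
    # mask (batch, time, x_dim)
    mask = tf.transpose(mask, [0, 2, 1])
    # zero out the masked time steps
    x = x * mask
    # divide by the number of unmasked steps
    return K.sum(x, axis=1) / K.sum(mask, axis=1)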
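To see what the two reshaping steps do to the mask, here is a NumPy sketch of the same arithmetic. NumPy is only a stand-in here: np.repeat and np.transpose play the roles of K.repeat and tf.transpose, and the toy shapes and values are made up for illustration. It also shows that adding a trailing axis to the mask and relying on broadcasting gives the same result:

```python
import numpy as np

batch, time, x_dim = 2, 4, 3
x = np.arange(batch * time * x_dim, dtype=float).reshape(batch, time, x_dim)
mask = np.array([[1, 1, 1, 0],
                 [1, 1, 0, 0]], dtype=float)       # (batch, time)

# analogue of K.repeat(mask, x_dim): (batch, x_dim, time)
repeated = np.repeat(mask[:, None, :], x_dim, axis=1)
# analogue of tf.transpose(mask, [0, 2, 1]): (batch, time, x_dim)
tiled = np.transpose(repeated, (0, 2, 1))

# broadcasting alternative: (batch, time, 1)
expanded = mask[:, :, None]

# masked mean both ways: zero the padded steps, divide by the unmasked count
mean_tiled = (x * tiled).sum(axis=1) / tiled.sum(axis=1)
mean_expanded = (x * expanded).sum(axis=1) / expanded.sum(axis=1)
```

The broadcasting form avoids materializing a full (batch, time, x_dim) copy of the mask; in the Keras backend the equivalent would be K.expand_dims(mask, -1).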
Rive answered 17/8, 2017 at 8:52

Since average pooling only computes a mean over one axis, you just need to correct the number of elements in the mean; loss masking is handled at the end, not here. Something like this should probably work:

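from keras import backend as K
from keras.layers import GlobalAveragePooling1D

class GlobalAveragePooling1DMasked(GlobalAveragePooling1D):
    def call(self, x, mask=None):
        if mask is not None:
            # mask (batch, time) -> float; summing with keepdims gives (batch, 1)
            mask = K.cast(mask, K.floatx())
            # assumes masked entries of x are zero (see the comments below)
            return K.sum(x, axis=1) / K.sum(mask, axis=1, keepdims=True)
        else:
            return super().call(x)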
Mucoprotein answered 16/9, 2016 at 14:35
Note that you can't be certain in Keras that masked values in x are equal to zero! Therefore, this implementation would give wrong results. - Misguided
The mask itself, introduced by the Masking and Embedding layers, is binary. Of course you can always have layers that implement compute_mask differently, but as far as I can see this does not happen in Keras itself. - Mucoprotein
Yes, the mask is binary. That is not what I'm saying :) - Misguided
But that's the only assumption I make in the example above. I don't understand what you are getting at, sorry. Edit: Never mind, it's the sum over x that assumes the masked values are zero. You are right. This could probably be fixed by indexing x with the mask, but I have to test this first. - Mucoprotein
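The point raised in these comments can be checked numerically. In this NumPy sketch (toy values made up for illustration), the padded time step holds a nonzero vector, as can happen after an LSTM, so dividing the raw sum by the unmasked count gives a wrong mean, while zeroing the masked steps first gives the right one:

```python
import numpy as np

# one sample, 3 time steps, 2 features; the last step is padding
x = np.array([[[1.0, 2.0],
               [3.0, 4.0],
               [3.0, 4.0]]])            # padded step is NOT zero here
mask = np.array([[1.0, 1.0, 0.0]])     # (batch, time)

count = mask.sum(axis=1, keepdims=True)               # (batch, 1) unmasked steps

naive = x.sum(axis=1) / count                         # padded step leaks into the sum
correct = (x * mask[:, :, None]).sum(axis=1) / count  # padded step zeroed first
```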

To make the masked values in x equal to zero, you can do this:

from keras import backend as K
from keras.engine.topology import Layer

class MeanPool(Layer):
    def __init__(self, **kwargs):
        self.supports_masking = True
        super(MeanPool, self).__init__(**kwargs)

    def compute_mask(self, input, input_mask=None):
        # do not pass the mask to the next layers
        return None

    def call(self, x, mask=None):
        if mask is None:
            # no mask: plain mean over the time axis
            return K.mean(x, axis=1)
        # mask (batch, time)
        mask = K.cast(mask, K.floatx())
        # mask (batch, time, 1); dimshuffle is a Theano-only tensor method
        mask = mask.dimshuffle(0, 1, 'x')
        # make the masked values in x equal to zero
        x = x * mask
        # divide by the number of unmasked steps, shape (batch, 1)
        return K.sum(x, axis=1) / K.sum(mask, axis=1)

    def get_output_shape_for(self, input_shape):
        # remove temporal dimension (Keras 1 name; Keras 2 uses compute_output_shape)
        return input_shape[0], input_shape[2]
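One edge case this layer leaves open: if every step of a sample is masked, K.sum(mask, axis=1) is zero and the division yields NaN. A minimal NumPy sketch of one possible guard, clamping the count to at least 1 so a fully masked sample pools to zeros (the helper name masked_mean is hypothetical; inside the layer, K.maximum would play the role of np.maximum):

```python
import numpy as np

def masked_mean(x, mask):
    """Mean of x (batch, time, dim) over the unmasked steps of mask (batch, time)."""
    mask = mask.astype(x.dtype)[:, :, None]       # (batch, time, 1)
    total = (x * mask).sum(axis=1)                # masked steps contribute zero
    count = np.maximum(mask.sum(axis=1), 1.0)     # avoid dividing by zero
    return total / count

x = np.ones((2, 3, 2))
mask = np.array([[1, 0, 0],
                 [0, 0, 0]])                      # second sample is fully masked
out = masked_mean(x, mask)
```

Whether zeros are the right output for an empty sequence depends on the model; the point of the sketch is only that the guard keeps NaN out of the gradients.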
Kano answered 22/3, 2017 at 5:54

This is how I did it in Keras 2 (borrowing from all of the answers and fixing the dimensions):

import tensorflow as tf
from keras import backend as K
from keras.layers import Layer

class MeanPool(Layer):
    def __init__(self, **kwargs):
        self.supports_masking = True
        super(MeanPool, self).__init__(**kwargs)

    def compute_mask(self, input, input_mask=None):
        # do not pass the mask to the next layers
        return None

    def call(self, x, mask=None):
        if mask is None:
            # no mask: plain mean over the time axis
            return K.mean(x, axis=1)
        # mask (batch, time)
        mask = K.cast(mask, K.floatx())
        # mask (batch, x_dim, time)
        mask = K.repeat(mask, K.int_shape(x)[-1])
        # mask (batch, time, x_dim)
        mask = tf.transpose(mask, [0, 2, 1])
        # zero out the masked time steps
        x = x * mask
        # divide by the number of unmasked steps
        return K.sum(x, axis=1) / K.sum(mask, axis=1)

    def compute_output_shape(self, input_shape):
        # remove temporal dimension
        return (input_shape[0], input_shape[2])
Terrieterrier answered 2/5, 2018 at 23:10
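The question also asked for max pooling, which none of the answers cover. Multiplying by the mask does not work for a max: a zeroed padded step wins whenever all real values of a feature are negative. One approach is to push masked steps to a very large negative number before taking the max. A NumPy sketch of that idea, under the assumption that a -1e9 offset is far below any real activation (the helper name masked_max is hypothetical; in a layer the same arithmetic could be written with K.cast, K.expand_dims and K.max):

```python
import numpy as np

def masked_max(x, mask):
    """Max of x (batch, time, dim) over the unmasked steps of mask (batch, time)."""
    mask = mask.astype(x.dtype)[:, :, None]       # (batch, time, 1)
    # add a large negative offset to masked steps so they never win the max
    shifted = x + (1.0 - mask) * -1e9
    return shifted.max(axis=1)

x = np.array([[[-1.0, -5.0],
               [-2.0, -3.0],
               [ 0.0,  0.0]]])                    # last step is zero padding
mask = np.array([[1, 1, 0]])
out = masked_max(x, mask)                         # padding must not win
```

As with the mean, a fully masked sample has no valid steps, so here it would pool to roughly -1e9; guard that case if it can occur.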

© 2022 - 2024 — McMap. All rights reserved.