How to implement Tensorflow batch normalization in LSTM

My current LSTM network looks like this.

rnn_cell = tf.contrib.rnn.BasicRNNCell(num_units=CELL_SIZE)
init_s = rnn_cell.zero_state(batch_size=1, dtype=tf.float32)  # very first hidden state
outputs, final_s = tf.nn.dynamic_rnn(
    rnn_cell,              # cell you have chosen
    tf_x,                  # input
    initial_state=init_s,  # the initial hidden state
    time_major=False,      # False: (batch, time step, input); True: (time step, batch, input)
)

# reshape 3D output to 2D for fully connected layer
outs2D = tf.reshape(outputs, [-1, CELL_SIZE])
net_outs2D = tf.layers.dense(outs2D, INPUT_SIZE)

# reshape back to 3D
outs = tf.reshape(net_outs2D, [-1, TIME_STEP, INPUT_SIZE])

Usually I apply tf.layers.batch_normalization for batch normalization, but I am not sure whether this works in an LSTM network.

b1 = tf.layers.batch_normalization(outputs, momentum=0.4, training=True)
d1 = tf.layers.dropout(b1, rate=0.4, training=True)

# reshape 3D output to 2D for fully connected layer
outs2D = tf.reshape(d1, [-1, CELL_SIZE])                       
net_outs2D = tf.layers.dense(outs2D, INPUT_SIZE)

# reshape back to 3D
outs = tf.reshape(net_outs2D, [-1, TIME_STEP, INPUT_SIZE])
Wigging asked 24/10, 2017 at 16:13 – Comments (2):
github.com/tensorflow/tensorflow/issues/1736 – Analyst
Following the link above, there is an implementation of batch norm for LSTM that hasn't been merged into master yet: github.com/tensorflow/tensorflow/pull/14106/commits – Cockoftherock

If you want to use batch norm for an RNN (LSTM or GRU), you can check out this implementation, or read the full description in the blog post.
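Setting the linked code aside, if you just batch-normalize the per-timestep outputs the way the question's snippet does, a rough TF 1.x sketch would look like the following (the sizes, the toy loss and the is_training placeholder are illustrative, not taken from the linked implementation). Note that tf.layers.batch_normalization registers its moving-average updates in UPDATE_OPS, so the train op has to depend on them:

import tensorflow as tf

TIME_STEP, INPUT_SIZE, CELL_SIZE = 10, 8, 64           # toy sizes, for illustration only

tf_x = tf.placeholder(tf.float32, [None, TIME_STEP, INPUT_SIZE])
tf_y = tf.placeholder(tf.float32, [None, TIME_STEP, INPUT_SIZE])
is_training = tf.placeholder(tf.bool, [])               # True while training, False at inference

cell = tf.contrib.rnn.BasicLSTMCell(num_units=CELL_SIZE)
outputs, _ = tf.nn.dynamic_rnn(cell, tf_x, dtype=tf.float32, time_major=False)

# normalize the feature axis of every time step's output
b1 = tf.layers.batch_normalization(outputs, momentum=0.4, training=is_training)

outs2D = tf.reshape(b1, [-1, CELL_SIZE])
net_outs2D = tf.layers.dense(outs2D, INPUT_SIZE)
outs = tf.reshape(net_outs2D, [-1, TIME_STEP, INPUT_SIZE])

loss = tf.losses.mean_squared_error(tf_y, outs)

# the moving mean/variance used at inference are updated through UPDATE_OPS,
# so the train op must run those updates as well
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)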

However, layer normalization has advantages over batch norm for sequence data. Specifically, "the effect of batch normalization is dependent on the mini-batch size and it is not obvious how to apply it to recurrent networks" (from the paper by Ba et al., "Layer Normalization").

Layer normalization instead normalizes the summed inputs within each layer, so its statistics are computed per sample rather than per batch. You can check out the implementation of layer normalization for a GRU cell; a sketch of the core operation follows.
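That GRU implementation sits behind a link and is not reproduced here, but as a rough illustration of the operation itself in TF 1.x (the layer_norm helper and its variable names are illustrative, not from the linked code):

import tensorflow as tf

def layer_norm(x, scope="layer_norm", epsilon=1e-5):
    """Normalize the summed inputs of a layer over the feature axis.

    Statistics are computed per sample rather than per mini-batch, so the
    result does not depend on the batch size -- the property that makes it
    a better fit for recurrent networks.
    """
    with tf.variable_scope(scope):
        feature_size = x.get_shape()[-1].value
        gain = tf.get_variable("gain", [feature_size],
                               initializer=tf.ones_initializer())
        bias = tf.get_variable("bias", [feature_size],
                               initializer=tf.zeros_initializer())
        mean, variance = tf.nn.moments(x, axes=[x.get_shape().ndims - 1],
                                       keep_dims=True)
        return gain * (x - mean) / tf.sqrt(variance + epsilon) + bias

# inside a GRU/LSTM cell this is applied to each gate's pre-activation, e.g.
# r = tf.sigmoid(layer_norm(tf.matmul(inputs, W_r) + tf.matmul(state, U_r), "r_gate"))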

Landan answered 14/8, 2019 at 7:13

Based on this paper: "Layer Normalization" - Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

TensorFlow now comes with tf.contrib.rnn.LayerNormBasicLSTMCell, an LSTM cell with layer normalization and recurrent dropout.

Find the documentation here.
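For reference, a minimal sketch of how that cell could slot into the setup from the question, assuming TF 1.x with contrib available (the sizes are placeholders):

import tensorflow as tf

CELL_SIZE, TIME_STEP, INPUT_SIZE = 64, 10, 8   # placeholder sizes

tf_x = tf.placeholder(tf.float32, [None, TIME_STEP, INPUT_SIZE])

# LSTM cell with layer normalization on the gate pre-activations
# and optional recurrent dropout
lstm_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
    num_units=CELL_SIZE,
    layer_norm=True,
    dropout_keep_prob=0.6,   # 1.0 disables recurrent dropout
)

outputs, final_s = tf.nn.dynamic_rnn(
    lstm_cell,
    tf_x,
    dtype=tf.float32,        # lets dynamic_rnn create the zero initial state
    time_major=False,
)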

Freezer answered 21/12, 2018 at 19:44
