Tensorflow LSTM character by character sequence prediction
I am attempting to replicate the character level language modeling demonstrated in the excellent article http://karpathy.github.io/2015/05/21/rnn-effectiveness/ using Tensorflow.

So far my attempts have failed. After training on 800 or so characters, my network typically collapses to predicting a single repeated character. I believe I have fundamentally misunderstood the way TensorFlow implements LSTMs, and perhaps RNNs in general, and I am finding the documentation difficult to follow.

Here is the essence of my code:

Graph definition

idata = tf.placeholder(tf.int32, [None,1])   # input byte; value 256 marks start/end of file
odata = tf.placeholder(tf.int32, [None,1])   # target output byte, i.e. the next byte in the sequence
source = tf.to_float(tf.one_hot(idata,257))  # input byte as one-hot float
target = tf.to_float(tf.one_hot(odata,257))  # target output as one-hot float

with tf.variable_scope("lstm01"):
    cell1 = tf.nn.rnn_cell.BasicLSTMCell(257)
    val1, state1 = tf.nn.dynamic_rnn(cell1, source, dtype=tf.float32)

output = val1

Loss Calculation

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(output, target))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)  
output_am = tf.argmax(output,2)
target_am = tf.argmax(target,2)
correct_prediction = tf.equal(output_am, target_am)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

Training

for i in range(0, source_data.size-1, batch_size):
    start = i
    stop = i + batch_size
    i_data = source_data[start:stop].reshape([-1,1])
    o_data = source_data[start+1:stop+1].reshape([-1,1])

    train_step.run(feed_dict={idata: i_data, odata: o_data})

    if i%(report_interval*batch_size) == 0:
        batch_out, fa = sess.run([output_am, accuracy], feed_dict={idata: i_data, odata: o_data, keep_prob: 1.0})

        print("step %d, training accuracy %s"%(i, str(fa)))
        print("i_data sample: %s"%str(squeeze(i_data)))
        print("o_data sample: %s"%str(squeeze(o_data)))
        print("batch sample: %s"%str(squeeze(batch_out)))

Output, using a 1MB Shakespeare file for training

step 0, training accuracy 0.0
i_data sample: [ 256.   70.  105.  114.  115.  116.   32.   67.  105.  116.]
o_data sample: [  70.  105.  114.  115.  116.   32.   67.  105.  116.  105.]
batch sample: [254  18 151  64  51 199  83 174 151 199]

step 400, training accuracy 0.2
i_data sample: [  32.   98.  101.   32.  100.  111.  110.  101.   58.   32.]
o_data sample: [  98.  101.   32.  100.  111.  110.  101.   58.   32.   97.]
batch sample: [ 32 101  32  32  32  32  10  32 101  32]

step 800, training accuracy 0.0
i_data sample: [ 112.   97.  114.  116.  105.   99.  117.  108.   97.  114.]
o_data sample: [  97.  114.  116.  105.   99.  117.  108.   97.  114.  105.]
batch sample: [101 101 101  32 101 101  32 101 101 101]

step 1200, training accuracy 0.1
i_data sample: [  63.   10.   10.   70.  105.  114.  115.  116.   32.   67.]
o_data sample: [  10.   10.   70.  105.  114.  115.  116.   32.   67.  105.]
batch sample: [ 32  32  32 101  32  32  32  32  32  32]

step 1600, training accuracy 0.2
i_data sample: [  32.  116.  105.  108.  108.   32.  116.  104.  101.   32.]
o_data sample: [ 116.  105.  108.  108.   32.  116.  104.  101.   32.   97.]
batch sample: [32 32 32 32 32 32 32 32 32 32]

This is clearly incorrect.

I think I am getting confused by the difference between 'batches' and 'sequences', and about whether or not the state of the LSTM is preserved between what I call 'batches' (i.e., sub-sequences).

I'm getting the impression that I've trained it using 'batches' of sequences of length 1, and that between each batch the state data is discarded. Consequently, it is simply finding the most commonly occurring symbol.
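
To spell out that suspicion with a quick shape check (a hypothetical illustration, not part of my actual training script):

import numpy as np

# With batch_size = 10, reshape([-1,1]) yields ten rows of one byte each.
# After the one-hot, dynamic_rnn therefore sees a tensor of shape [10, 1, 257]:
# ten *independent* sequences of time length 1, so no state carries over
# from one character to the next.
source_data = np.frombuffer(b"First Citizen:", dtype=np.uint8).astype(np.int32)
i_data = source_data[0:10].reshape([-1, 1])
print(i_data.shape)  # (10, 1) -> batch of 10, sequence length 1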

Can anyone confirm this or otherwise correct my mistake, and give some indication of how I should go about character-by-character prediction using very long training sequences?

Many thanks.

Australopithecus answered 12/1, 2017 at 19:25

So the input you hand to dynamic_rnn should have a shape of [batch_size, maximum_sequence_length, 257]. (If not all sequences in a batch have the same length, you need to pad as necessary, and be careful that the loss is computed only over the non-padded values.)

dynamic_rnn steps through your input over time for you, so you only need to loop over batches.

Since the second dimension of your input is 1, you are right that your effective sequence length is 1.
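
A minimal sketch of that shaping, assuming a fixed sequence length for simplicity (batch_size, seq_len, and make_batch are illustrative names, not from the question):

import numpy as np
import tensorflow as tf

batch_size, seq_len = 4, 50  # illustrative values

# the index placeholders now carry an explicit time dimension: [batch_size, seq_len]
idata = tf.placeholder(tf.int32, [None, seq_len])
odata = tf.placeholder(tf.int32, [None, seq_len])
source = tf.one_hot(idata, 257)  # [batch_size, seq_len, 257], float32

cell = tf.nn.rnn_cell.BasicLSTMCell(257)
# dynamic_rnn unrolls over the time dimension internally, so the
# Python training loop only has to step over batches
val, state = tf.nn.dynamic_rnn(cell, source, dtype=tf.float32)

def make_batch(data, offset):
    """Slice the flat byte stream into one [batch_size, seq_len] window."""
    chunk = data[offset:offset + batch_size * seq_len]
    return chunk.reshape([batch_size, seq_len])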

For a language model that is word-embedding-based rather than character-based, take a look at this tutorial.

Other notes:

  • If you want to experiment with a different number of units in the LSTM, consider adding a linear layer on top of the output to project each output (for batch entry i at time t) down to 257, the number of classes of your target (see the sketch after this list).

  • There is no need to one-hot encode the target. Take a look at sparse_softmax_cross_entropy_with_logits (also shown in the sketch below).
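
A sketch combining both points, continuing from the shapes above (num_units is an illustrative size; the reshape/matmul is one common way to apply a linear layer at every time step, not necessarily the only one):

num_units, num_classes = 512, 257  # illustrative sizes

cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
val, state = tf.nn.dynamic_rnn(cell, source, dtype=tf.float32)

# linear projection applied at every (batch, time) position:
# flatten to 2-D, matmul, then restore the time dimension
W = tf.get_variable("proj_w", [num_units, num_classes])
b = tf.get_variable("proj_b", [num_classes])
flat = tf.reshape(val, [-1, num_units])            # [batch*time, num_units]
logits = tf.reshape(tf.matmul(flat, W) + b,
                    [tf.shape(val)[0], -1, num_classes])

# odata keeps its raw integer byte values; no one-hot needed
cross_entropy = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=odata, logits=logits))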

Mcgrath answered 13/1, 2017 at 1:25
I'm considering my training data as a single very long sequence. I expect this to cause memory issues if I attempt to put multi-megabyte sequences through the network. Do I simply read out the state variable and re-feed it into the next execution of the graph as the initial_state of the dynamic_rnn op? (Australopithecus)
Retaining the state of RNNs is addressed here: #38241910, here: #37969565, here: github.com/tensorflow/tensorflow/issues/3476, and also here: github.com/tensorflow/tensorflow/issues/2838 (Australopithecus)
I would echo danijar's suggestion to use static_state_saving_rnn. (Mcgrath)
I looked at the API for state_saving_rnn. It requires a state_saver object, which I'm not sure how to provide (it seems the only reference to a state saver in the docs is about saving to disk). So instead I used user1506145's solution, which itself wasn't complete, because you can't pass an LSTMStateTuple into a feed. It took me a little while to figure out I had to do this: for n in range(number_of_layers): init_state[n,0] = state[n].c; init_state[n,1] = state[n].h to copy the data from the LSTMStateTuple into the original list. (Australopithecus)
Reinserting the state data doesn't seem to work either. I'll post another question when I've figured out what the question is. (Australopithecus)
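
For reference, a minimal sketch of the state re-feeding approach discussed in these comments, assuming a single BasicLSTMCell and the shapes from the earlier sketches (num_units, num_steps, and next_subsequence are illustrative names, not from the thread):

import numpy as np
import tensorflow as tf

# placeholders so the LSTM state can be handed back in between session runs
c_in = tf.placeholder(tf.float32, [None, num_units])
h_in = tf.placeholder(tf.float32, [None, num_units])
init_state = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in)

cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
val, state = tf.nn.dynamic_rnn(cell, source, initial_state=init_state)

# training loop: read the final state out and feed it back in, so the
# LSTM's memory survives across consecutive sub-sequences
c = np.zeros([batch_size, num_units], dtype=np.float32)
h = np.zeros([batch_size, num_units], dtype=np.float32)
for step in range(num_steps):
    i_data, o_data = next_subsequence(step)  # hypothetical data helper
    _, c, h = sess.run([train_step, state.c, state.h],
                       feed_dict={idata: i_data, odata: o_data,
                                  c_in: c, h_in: h})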
