Seq2Seq model learns to only output EOS token (<\s>) after a few iterations

I am creating a chatbot trained on the Cornell Movie Dialogs Corpus using NMT.

I am basing my code in part on https://github.com/bshao001/ChatLearner and https://github.com/chiphuyen/stanford-tensorflow-tutorials/tree/master/assignments/chatbot

During training, to observe the learning progress, I print a random target answer from the batch (the sequence fed to the decoder) and the corresponding answer that my model predicts.

My issue: after only about 4 training iterations, the model learns to output the EOS token (<\s>) for every timestep. It always outputs that as its response (determined using the argmax of the logits), even as training continues. Once in a while (rarely), the model outputs a series of periods as its answer.

I also print the top 10 logit values during training (not just the argmax) to see if maybe the correct word is somewhere in there, but it seems to be predicting the most common words in the vocab (e.g. i, you, ?, .). Even these top 10 words don't change much during training.
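
Roughly, I pull the top 10 with tf.nn.top_k (a simplified sketch; the actual variable names in my code differ):

import tensorflow as tf

# decoder_logits: [batch_size, max_time, vocab_size]
top10_values, top10_indices = tf.nn.top_k(decoder_logits, k=10)
# after sess.run, I map the indices of one timestep back to words, e.g.
# print([index_to_word[i] for i in top10_indices_val[0, t]])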

I have made sure to correctly count the input sequence lengths for the encoder and decoder, and to add SOS (<s>) and EOS (also used for padding) tokens accordingly. I also apply masking in the loss calculation.
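
For context, my masking roughly follows the NMT tutorial's pattern (a simplified sketch; the exact names in my code differ):

import tensorflow as tf

# decoder_logits:  [batch_size, max_time, vocab_size]
# decoder_targets: [batch_size, max_time], padded with EOS
# target_lengths:  [batch_size], true lengths including the final EOS
crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=decoder_targets, logits=decoder_logits)        # [batch_size, max_time]
mask = tf.sequence_mask(target_lengths,
                        maxlen=tf.shape(decoder_targets)[1],
                        dtype=tf.float32)                  # 1.0 for real tokens, 0.0 for padding
loss = tf.reduce_sum(crossent * mask) / tf.to_float(tf.shape(decoder_targets)[0])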

Here is a sample output:

Training iteration 1:

Decoder Input: <s> sure . sure . <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s>
Predicted Answer: wildlife bakery mentality mentality administration administration winston winston winston magazines magazines magazines magazines

...

Training iteration 4:

Decoder Input: <s> i guess i had it coming . let us call it settled . <\s> <\s> <\s> <\s> <\s>
Predicted Answer: <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s>


After a few more iterations, it settles on only predicting EOS (and, rarely, some periods).

I am not sure what could be causing this issue and have been stuck on this for a while. Any help would be greatly appreciated!

Update: I let it train for over a hundred thousand iterations and it still only outputs EOS (and occasional periods). The training loss also does not decrease after the first few iterations (it remains at around 47 from the beginning).

Boult answered 28/9, 2018 at 21:34 Comment(4)
What constitutes an "iteration"? Is it a minibatch? An epoch? Either way, this behaviour does not surprise me that much. When I train RNNs they usually go through a phase, early on during training, where they repeatedly output the same symbol. The solution may simply be that you need to train the model for longer. If the behaviour persists after training for many epochs then something may be wrong.Sogdian
An iteration in this case is just applying gradient descent to a single random batch. I have let it train for a few thousand iterations and the predicted output is always EOS. Even when I inspect the top 10 logits as training progresses (not just the max used for the prediction output), it always seems to be the highest-frequency (most common) words in the vocab that have the highest logits. I am not sure what could be causing this problem as I based my code off the NMT tutorialBoult
@Sogdian Update: I let it train for over a hundred thousand iterations and it still only outputs EOS (and occasional periods). The training loss also does not decrease after the first iteration (it remains at around 47)Boult
Noel, did you ever find a solution to this? I am facing the same issue, and I followed the advice from @Sogdian too. My model gets around 98% accuracy, then drops down to 5%, then climbs back up to 20%, but it still only predicts end tokens. I have no idea why the accuracy is even changing when it only outputs the argmax, which is always the end tokenEulalia

Recently I also worked on a seq2seq model. I encountered your problem before; in my case, I solved it by changing the loss function.

You said you use a mask, so I guess you use tf.contrib.seq2seq.sequence_loss as I did.

I changed to tf.nn.softmax_cross_entropy_with_logits, and it worked normally (at a higher computation cost).

(Edit 05/10/2018: pardon me, I need to edit this since I found an egregious mistake in my code.)

tf.contrib.seq2seq.sequence_loss can work really well if the shapes of logits, targets, and weights are right, as defined in the official documentation for tf.contrib.seq2seq.sequence_loss:

loss=tf.contrib.seq2seq.sequence_loss(logits=decoder_logits,
                                      targets=decoder_targets,
                                      weights=masks) 

#logits:  [batch_size, sequence_length, num_decoder_symbols]  
#targets: [batch_size, sequence_length] 
#weights: [batch_size, sequence_length] 

Well, it can still run even if the shapes do not match, but the results can be weird (lots of #EOS, #PAD, etc.).

Your decoder outputs and decoder_targets might not have the required shapes (in my case, my decoder_targets had the shape [sequence_length, batch_size]), so use tf.transpose to reorder the dimensions.
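
For example, a minimal sketch of what I mean (assuming time-major tensors, as in my case; adjust if yours differ):

import tensorflow as tf

# decoder_targets: [sequence_length, batch_size]             (time-major)
# decoder_logits:  [sequence_length, batch_size, vocab_size]
# masks:           [sequence_length, batch_size]
# sequence_loss expects batch-major tensors, so swap the first two axes:
decoder_targets = tf.transpose(decoder_targets, perm=[1, 0])    # -> [batch_size, sequence_length]
decoder_logits  = tf.transpose(decoder_logits, perm=[1, 0, 2])  # -> [batch_size, sequence_length, vocab_size]
masks           = tf.transpose(masks, perm=[1, 0])              # -> [batch_size, sequence_length]

loss = tf.contrib.seq2seq.sequence_loss(logits=decoder_logits,
                                        targets=decoder_targets,
                                        weights=masks)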

Crumpled answered 4/10, 2018 at 3:6 Comment(1)
I'm facing exactly the same problem; does anyone know how to solve this? Do the logits passed to sequence_loss need to be softmaxed?Euphemia

In my case, it was due to the optimizer: I mistakenly set too large an lr_decay, so training no longer worked normally.

Checking the learning rate and the optimizer/scheduler settings may help, for example:
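
(A TensorFlow 1.x sketch, not my exact code; the point is just to keep the decay mild so the learning rate does not collapse toward zero.)

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()

# An overly aggressive decay shrinks the learning rate to almost nothing after a few
# thousand steps, so the model gets stuck at whatever it has learned so far.
learning_rate = tf.train.exponential_decay(learning_rate=0.001,
                                           global_step=global_step,
                                           decay_steps=10000,
                                           decay_rate=0.96,   # keep this close to 1.0
                                           staircase=True)

optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.minimize(loss, global_step=global_step)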

Pressurize answered 13/4, 2021 at 8:55 Comment(0)

Reducing the learning rate, as mentioned in one of the answers above, did help to some extent.

It's a sort of over-fitting: the model seems to have come across too many EOS tokens during learning. I had a similar issue; initially I had a fixed 'max_new_tokens' for the model responses.

Introducing random integer values within a specified range helped a lot to overcome this.

response_generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}

Within the training loop, change "max_new_tokens" on every iteration:

from random import randint

gen_len = randint(30, 60)
response_generation_kwargs["max_new_tokens"] = gen_len
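
Then pass the kwargs to generation as usual (a sketch; model, tokenizer, and query_tensors here stand in for whatever your training loop already uses, assuming a Hugging Face-style generate API):

response_tensors = model.generate(query_tensors, **response_generation_kwargs)  # query_tensors: tokenized prompts
responses = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)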
Flavouring answered 23/7 at 5:51 Comment(0)
