I am creating a chatbot trained on the Cornell Movie Dialogs Corpus using NMT (seq2seq).
I am basing my code in part on https://github.com/bshao001/ChatLearner and https://github.com/chiphuyen/stanford-tensorflow-tutorials/tree/master/assignments/chatbot
During training, to observe learning progress, I print a random target answer from the batch (as it is fed to the decoder) alongside the corresponding answer that my model predicts.
My issue: After only about 4 training iterations, the model learns to output the EOS token (<\s>) at every timestep. It keeps outputting that as its response (determined by taking the argmax of the logits) even as training continues. Occasionally, the model outputs a series of periods as its answer instead.
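For reference, the monitoring step looks roughly like this (a minimal sketch of my logging, not the exact code; decoder_logits, batch_decoder_inputs, and rev_vocab are stand-in names for the fetched logits, the batch's decoder inputs, and an id-to-word lookup list):

```python
import numpy as np

# decoder_logits:       [batch, time, vocab] array fetched from one training step
# batch_decoder_inputs: [batch, time] token ids fed to the decoder
# rev_vocab:            list mapping token id -> word (stand-in names)
i = np.random.randint(decoder_logits.shape[0])        # pick a random example
pred_ids = np.argmax(decoder_logits[i], axis=-1)      # greedy argmax per timestep
print("Decoder Input:    " + " ".join(rev_vocab[t] for t in batch_decoder_inputs[i]))
print("Predicted Answer: " + " ".join(rev_vocab[t] for t in pred_ids))
```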
I also print the top 10 logit values during training (not just the argmax) to see whether the correct word appears somewhere among them, but the model seems to be predicting the most common words in the vocabulary (e.g. i, you, ?, .). Even these top 10 words barely change during training.
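The top-10 inspection is along the same lines (same stand-in names as in the sketch above):

```python
# For the same sampled example, list the 10 highest-scoring tokens per timestep
top10 = np.argsort(decoder_logits[i], axis=-1)[:, :-11:-1]   # descending by logit
for t, ids in enumerate(top10):
    print("t=%d: %s" % (t, ", ".join(rev_vocab[k] for k in ids)))
```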
I have made sure to correctly count the input sequence lengths for the encoder and decoder, and to add the SOS (<s>) and EOS (also used for padding) tokens accordingly. I also apply masking in the loss calculation.
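The loss masking follows the usual TF 1.x pattern (a sketch assuming tf.contrib.seq2seq; targets, logits, and target_lengths are placeholders for my tensors):

```python
import tensorflow as tf

# targets:        [batch, time] token ids, padded with the EOS id
# logits:         [batch, time, vocab] decoder outputs
# target_lengths: [batch] true lengths (real tokens + one terminating EOS)
mask = tf.sequence_mask(target_lengths,
                        maxlen=tf.shape(targets)[1],
                        dtype=tf.float32)
# sequence_loss averages the masked cross-entropy over batch and time
loss = tf.contrib.seq2seq.sequence_loss(logits, targets, weights=mask)
```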
Here is a sample output:
Training iteration 1:
Decoder Input: <s> sure . sure . <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s>
Predicted Answer: wildlife bakery mentality mentality administration administration winston winston winston magazines magazines magazines magazines
...
Training iteration 4:
Decoder Input: <s> i guess i had it coming . let us call it settled . <\s> <\s> <\s> <\s> <\s>
Predicted Answer: <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s>
After a few more iterations, it settles on predicting only EOS (and rarely some periods).
I am not sure what could be causing this issue and have been stuck on this for a while. Any help would be greatly appreciated!
Update: I let it train for over a hundred thousand iterations and it still only outputs EOS (and occasional periods). The training loss also does not decrease after the first few iterations; it stays at around 47 from the beginning.
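For scale: a model predicting uniformly over the vocabulary would have a per-token cross-entropy of ln(V), which I sanity-checked as below (vocab_size is a stand-in for my actual vocabulary size), so the flat ~47 doesn't look like a plain averaged per-token loss to me:

```python
import math

vocab_size = 24000            # stand-in for my actual vocabulary size
print(math.log(vocab_size))   # ~10.1 for a 24k vocabulary, well below the ~47 I see
```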