I'm trying to train an RNN for digital (audio) signal processing using deeplearning4j. The idea is to have two .wav files: one is an audio recording, the second is the same audio recording but processed (for example with a low-pass filter). The RNN's input is the 1st (unprocessed) audio recording, the output is the 2nd (processed) audio recording.
I've used the GravesLSTMCharModellingExample from the dl4j examples, and mostly adapted the CharacterIterator class to accept audio data instead of text.
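For reference, the example this is based on configures its network roughly like this (a sketch against the dl4j 0.x-era API of the example, from memory; `lstmLayerSize` and `tbpttLength` are placeholders, and the 256 input/output units stand in for the one-hot 8-bit sample classes):

```java
// Sketch only: mirrors the GravesLSTMCharModellingExample configuration,
// with nIn/nOut of 256 for the one-hot encoded 8-bit audio samples.
int lstmLayerSize = 200;  // placeholder value
int tbpttLength = 50;     // placeholder truncated-BPTT length

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(12345)
    .weightInit(WeightInit.XAVIER)
    .list()
    .layer(0, new GravesLSTM.Builder().nIn(256).nOut(lstmLayerSize)
        .activation("tanh").build())
    .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
        .activation("softmax").nIn(lstmLayerSize).nOut(256).build())
    .backpropType(BackpropType.TruncatedBPTT)
    .tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
    .build();
```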
My first project working with audio in dl4j was to do basically the same thing as GravesLSTMCharModellingExample but generate audio instead of text, using 11025 Hz 8-bit mono audio. That works (with some quite amusing results), so the basics of working with audio in this context seem sound.
So step 2 was to adapt this for audio processing instead of audio generation.
Unfortunately, I'm not having much success. The best it seems to be able to do is output a very noisy version of the input.
As a 'sanity check', I've tested using the same audio file for both the input and the output, which I expected to converge quickly to a model simply copying the input. But it doesn't. Again, after a long time of training, all it seemed to be able to do is produce a noisier version of the input.
The most relevant piece of code, I guess, is the DataSetIterator.next() method (adapted from the example's CharacterIterator class), which now looks like this:
public DataSet next(int num) {
    if (exampleStartOffsets.size() == 0)
        throw new NoSuchElementException();

    int currMinibatchSize = Math.min(num, exampleStartOffsets.size());
    // Allocate space. Note the order here:
    // dimension 0 = number of examples in minibatch
    // dimension 1 = size of each vector (number of possible sample values; 256 for 8-bit audio)
    // dimension 2 = length of each time series/example
    // Why 'f' order here? See http://deeplearning4j.org/usingrnns.html#data
    // section "Alternative: Implementing a custom DataSetIterator"
    INDArray input = Nd4j.create(new int[] { currMinibatchSize, columns, exampleLength }, 'f');
    INDArray labels = Nd4j.create(new int[] { currMinibatchSize, columns, exampleLength }, 'f');
    for (int i = 0; i < currMinibatchSize; i++) {
        int startIdx = exampleStartOffsets.removeFirst();
        int endIdx = startIdx + exampleLength;
        for (int j = startIdx, c = 0; j < endIdx; j++, c++) {
            // inputIndices/idealIndices are audio samples converted to indices.
            // With 8-bit audio, this translates to values between 0-255.
            input.putScalar(new int[] { i, inputIndices[j], c }, 1.0);
            labels.putScalar(new int[] { i, idealIndices[j], c }, 1.0);
        }
    }
    return new DataSet(input, labels);
}
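To clarify the comment about inputIndices/idealIndices, this is the kind of conversion involved (a minimal sketch, assuming unsigned 8-bit PCM samples; `SampleIndexer` and its methods are hypothetical names, not part of my actual code):

```java
// Hypothetical helper illustrating the sample-to-index conversion:
// 8-bit PCM WAV data stores each sample as an unsigned byte (0-255),
// while Java's byte type is signed (-128..127).
public class SampleIndexer {

    // Maps a raw signed Java byte to the unsigned 0-255 index used as
    // the one-hot class for that time step.
    public static int toIndex(byte rawSample) {
        return rawSample & 0xFF;
    }

    // Converts a whole buffer of raw 8-bit PCM bytes to indices.
    public static int[] toIndices(byte[] rawSamples) {
        int[] indices = new int[rawSamples.length];
        for (int i = 0; i < rawSamples.length; i++) {
            indices[i] = toIndex(rawSamples[i]);
        }
        return indices;
    }

    public static void main(String[] args) {
        byte[] raw = { 0, 127, -128, -1 };
        // Unsigned interpretation: 0 -> 0, 127 -> 127, -128 -> 128, -1 -> 255
        System.out.println(java.util.Arrays.toString(toIndices(raw)));
    }
}
```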
So maybe I have a fundamental misunderstanding of what LSTMs are supposed to do. Is there anything obviously wrong in the posted code that I'm missing? Is there an obvious reason why training on the same file doesn't quickly converge to a model that simply copies the input (let alone one trained on signal processing that actually does something)?
I've seen the question "Using RNN to recover sine wave from noisy signal", which seems to be about a similar problem (but using a different ML framework), but it didn't get an answer.
Any feedback is appreciated!