deeplearning4j - using an RNN/LSTM for audio signal processing

I'm trying to train an RNN for digital (audio) signal processing using deeplearning4j. The idea is to have two .wav files: one is an audio recording, and the second is the same recording but processed (for example, with a low-pass filter). The RNN's input is the first (unprocessed) recording, and its output is the second (processed) recording.

I've used the GravesLSTMCharModellingExample from the dl4j examples, and mostly adapted the CharacterIterator class to accept audio data instead of text.

My first project working with audio in dl4j was basically to do the same thing as GravesLSTMCharModellingExample but generate audio instead of text, using 11025Hz 8-bit mono audio. That works (with some quite amusing results), so the basics of working with audio in this context seem fine.
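For reference, the conversion from .wav bytes to the sample indices used below is roughly the following. This is a simplified sketch, not my exact code, and it assumes unsigned 8-bit PCM (which is what 8-bit WAV uses):

import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.UnsupportedAudioFileException;
import java.io.File;
import java.io.IOException;

public class WavToIndices {
    /** Reads an 8-bit mono WAV file and returns each sample as an index in 0-255. */
    public static int[] readIndices(File wavFile)
            throws IOException, UnsupportedAudioFileException {
        try (AudioInputStream in = AudioSystem.getAudioInputStream(wavFile)) {
            byte[] raw = in.readAllBytes();
            int[] indices = new int[raw.length];
            for (int i = 0; i < raw.length; i++) {
                indices[i] = raw[i] & 0xFF; // 8-bit WAV PCM is unsigned: mask to 0-255
            }
            return indices;
        }
    }
}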

So step 2 was to adapt this for audio processing instead of audio generation.

Unfortunately, I'm not having much success. The best it seems able to do is output a very noisy version of the input.

As a 'sanity check', I've tested using the same audio file for both the input and the output, which I expected to converge quickly to a model that simply copies the input. But it doesn't. Again, after a long time of training, all it seemed able to do was produce a noisier version of the input.

The most relevant piece of code, I guess, is the DataSetIterator.next() method (adapted from the example's CharacterIterator class), which now looks like this:

public DataSet next(int num) {
    if (exampleStartOffsets.size() == 0)
        throw new NoSuchElementException();

    int currMinibatchSize = Math.min(num, exampleStartOffsets.size());
    // Allocate space:
    // Note the order here:
    // dimension 0 = number of examples in minibatch
    // dimension 1 = size of each vector (i.e., number of possible sample values: 256 for 8-bit audio)
    // dimension 2 = length of each time series/example
    // Why 'f' order here? See http://deeplearning4j.org/usingrnns.html#data
    // section "Alternative: Implementing a custom DataSetIterator"
    INDArray input = Nd4j.create(new int[] { currMinibatchSize, columns, exampleLength }, 'f');
    INDArray labels = Nd4j.create(new int[] { currMinibatchSize, columns, exampleLength }, 'f');

    for (int i = 0; i < currMinibatchSize; i++) {
        int startIdx = exampleStartOffsets.removeFirst();
        int endIdx = startIdx + exampleLength;

        for (int j = startIdx, c = 0; j < endIdx; j++, c++) {
            // inputIndices/idealIndices are audio samples converted to indices.
            // With 8-bit audio, this translates to values between 0-255.
            input.putScalar(new int[] { i, inputIndices[j], c }, 1.0);
            labels.putScalar(new int[] { i, idealIndices[j], c }, 1.0);
        }
    }

    return new DataSet(input, labels);
}

So maybe I have a fundamental misunderstanding of what LSTMs are supposed to do. Is there anything obviously wrong in the posted code that I'm missing? Is there an obvious reason why training on the same file doesn't converge quickly to a model that simply copies the input (let alone trying to train it on signal processing that actually does something)?

I've seen Using RNN to recover sine wave from noisy signal, which seems to be about a similar problem (but with a different ML framework), but that question didn't get an answer.

Any feedback is appreciated!

Jestude asked 6/5, 2017 at 21:44 (4 comments)
Can you answer the following about your project: why use an LSTM network architecture? Also relevant to providing guidance would be seeing how you are batching the input. Are you performing any type of normalization on it? – Hasheem
My thinking behind using an LSTM is that I'm training on data where the sequence of the data matters, and I'm hoping the NN will learn something from what data came before (unlike something like a 'normal' feed-forward NN). – Jestude
As to batching, I'm using a mini-batch size of 32, an example size of 10000, and a TBPTT length of 1000, although I'm experimenting a lot with these values (see the sketch after these comments). – Jestude
@Jestude I am doing an ASR project and posted a question about it at the following link: datascience.stackexchange.com/questions/65497/… . It would be great if you could answer part of the question. – Sailor
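For context, the batching and truncated-BPTT values mentioned in the comments would map onto a dl4j configuration along these lines. This is only a sketch with a recent dl4j API: the layer sizes, updater, and learning rate are assumptions, not the asker's actual settings.

import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class ConfigSketch {
    public static MultiLayerConfiguration lstmConfig() {
        // The mini-batch size of 32 is set on the DataSetIterator, not here.
        return new NeuralNetConfiguration.Builder()
                .updater(new Adam(1e-3))            // learning rate is an assumption
                .list()
                .layer(0, new GravesLSTM.Builder()
                        .nIn(256).nOut(200)         // 256 = one-hot size for 8-bit samples; 200 is a guess
                        .activation(Activation.TANH).build())
                .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .nIn(200).nOut(256).build())
                .backpropType(BackpropType.TruncatedBPTT)
                .tBPTTForwardLength(1000)           // TBPTT length of 1000, as in the comment
                .tBPTTBackwardLength(1000)
                .build();
    }
}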

If you hear a distorted version of the input, you are on the right track.

The problem might be that the network's free parameters cannot generalize well from a small number of examples. Make sure you have more samples, at least 50,000 that do not overlap each other (i.e., not all from the same .wav file), and try to play with the network parameters: for example, reduce the number of nodes in each layer by 10-15% and try a lower learning rate.
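In a dl4j builder like the one sketched under the question's comments, those two suggestions would amount to a fragment along these lines (the numbers are placeholders, not values from this answer):

// Placeholders only: ~15% fewer hidden units and a 10x lower learning rate.
.updater(new Adam(1e-4))                       // was new Adam(1e-3)
.layer(0, new GravesLSTM.Builder()
        .nIn(256).nOut(170)                    // was .nOut(200)
        .activation(Activation.TANH).build())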

Exhort answered 21/4, 2021 at 7:39

The most common issue with problems like this is the training data.

  1. Make sure you have enough training data available. If you don't, you can use a library like audiomentations to augment your training set (a minimal Java analogue is sketched after this list).
  2. Diversity of training data: the more perturbations you can add to your training set, the better.
  3. Hyperparameter optimization: neural networks in general require a lot of parameter tuning to perform above average. See Parameter optimization in deeplearning4j.
  4. This is a suggestion based on past experience and may be out of scope, but an autoencoder architecture usually does wonders for these processing use cases (audio, images, etc.).
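audiomentations itself is a Python library; in Java, a minimal augmentation along the same lines (random gain plus Gaussian noise, with placeholder names and values) could look like the sketch below. Note that for an input-to-output mapping task like this one, any augmentation would have to be applied consistently to the input/target pair.

import java.util.Random;

public class SimpleAugment {
    private static final Random RNG = new Random();

    /** Applies a random gain and Gaussian noise to 8-bit samples (0-255). */
    public static int[] augment(int[] samples, double noiseStd, double maxGainDb) {
        double gainDb = (RNG.nextDouble() * 2 - 1) * maxGainDb;  // e.g. +/- 3 dB
        double gain = Math.pow(10.0, gainDb / 20.0);
        int[] out = new int[samples.length];
        for (int i = 0; i < samples.length; i++) {
            double centered = samples[i] - 128;                  // shift to signed range
            double v = centered * gain + RNG.nextGaussian() * noiseStd;
            out[i] = (int) Math.max(0, Math.min(255, Math.round(v) + 128));
        }
        return out;
    }
}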
Shoshone answered 9/5, 2021 at 0:02

Hello, I think that in the dataset logic you should try using a long type instead of an integer:

public DataSet next(int num)

replace with

public DataSet next(long num)
Uzial answered 19/5, 2017 at 18:05 (1 comment)
That would violate the DataSetIterator interface, and I'm not sure how that is relevant to the problem anyway. – Jestude
