Neural Network returning NaN as output

I am trying to write my first neural network to play the game Connect Four. I'm using Java and deeplearning4j. I tried to implement a genetic algorithm, but when I train the network for a while, the outputs of the network jump to NaN, and I am unable to tell where I messed up so badly for this to happen. I will post all 3 classes below, where Game is the game logic and rules, VGFrame the UI, and Main all the NN stuff.

I have a pool of 35 neural networks, and each iteration I let the best 5 live and breed, and randomize the newly created ones a little. To evaluate the networks I let them battle each other, giving points to the winner and points to the loser for losing late in the game. Since I penalize putting a stone into a column that's already full, I expected the neural networks at least to be able to play the game by the rules after a while, but they can't do this. I googled the NaN problem and it seems to be an exploding gradient problem, but from my understanding this shouldn't occur in a genetic algorithm? Any ideas where I could look for the error, or what's generally wrong with my implementation?

Main

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;

public class Main {
    final int numRows = 7;
    final int numColums = 6;
    final int randSeed = 123;
    MultiLayerNetwork[] models;

    static Random random = new Random();
    private static final Logger log = LoggerFactory.getLogger(Main.class);
    final float learningRate = .8f;
    int batchSize = 64; // Test batch size
    int nEpochs = 1; // Number of training epochs
    // --
    public static Main current;
    Game mainGame = new Game();

    public static void main(String[] args) {
        current = new Main();
        current.frame = new VGFrame();
        current.loadWeights();
    }

    private VGFrame frame;
    private final double mutationChance = .05;

    public Main() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
                .activation(Activation.RELU).seed(randSeed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Nesterovs(0.1, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build())
                .layer(new DenseLayer.Builder().nIn(30).nOut(15).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(15).nOut(7)
                        .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build())
                .build();
        models = new MultiLayerNetwork[35];
        for (int i = 0; i < models.length; i++) {
            models[i] = new MultiLayerNetwork(conf);
            models[i].init();
        }

    }

    public void addChip(int i, boolean b) {
        if (mainGame.gameState == 0)
            mainGame.addChip(i, b);
        if (mainGame.gameState == 0) {
            float[] f = Main.rowsToInput(mainGame.rows);
            INDArray input = Nd4j.create(f);
            INDArray output = models[0].output(input);
            for (int i1 = 0; i1 < 7; i1++) {
                System.out.println(i1 + ": " + output.getDouble(i1));
            }
            System.out.println("----------------");
            mainGame.addChip(Main.getHighestOutput(output), false);
        }
        getFrame().paint(getFrame().getGraphics());
    }

    public void newGame() {
        mainGame = new Game();
        getFrame().paint(getFrame().getGraphics());
    }

    public void startTraining(int iterations) {

        // --------------------------
        for (int gameNumber = 0; gameNumber < iterations; gameNumber++) {
            System.out.println("Iteration " + gameNumber + " of " + iterations);
            float[] evaluation = new float[models.length];
            for (int i = 0; i < models.length; i++) {
                for (int j = 0; j < models.length; j++) {
                    if (i != j) {
                        Game g = new Game();
                        g.playFullGame(models[i], models[j]);
                        if (g.gameState == 1) {
                            evaluation[i] += 45;
                            evaluation[j] += g.turnNumber;
                        }
                        if (g.gameState == 2) {
                            evaluation[j] += 45;
                            evaluation[i] += g.turnNumber;
                        }
                    }
                }
            }

            float[] evaluationSorted = evaluation.clone();
            Arrays.sort(evaluationSorted);
            // keep the best 4
            int n1 = 0, n2 = 0, n3 = 0, n4 = 0, n5 = 0;
            for (int i = 0; i < evaluation.length; i++) {
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 1])
                    n1 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 2])
                    n2 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 3])
                    n3 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 4])
                    n4 = i;
                if (evaluation[i] == evaluationSorted[evaluationSorted.length - 5])
                    n5 = i;
            }
            models[0] = models[n1];
            models[1] = models[n2];
            models[2] = models[n3];
            models[3] = models[n4];
            models[4] = models[n5];

            for (int i = 3; i < evaluationSorted.length; i++) {
                // random parent/keep w8ts
                double r = Math.random();
                if (r > .3) {
                    models[i] = models[random.nextInt(3)].clone();

                } else if (r > .1) {
                    models[i].setParams(breed(models[random.nextInt(3)], models[random.nextInt(3)]));
                }
                // Mutate
                INDArray params = models[i].params();
                models[i].setParams(mutate(params));
            }
        }
    }

    private INDArray mutate(INDArray params) {
        double[] d = params.toDoubleVector();
        for (int i = 0; i < d.length; i++) {
            if (Math.random() < mutationChance)
                d[i] += (Math.random() - .5) * learningRate;

        }
        return Nd4j.create(d);
    }

    private INDArray breed(MultiLayerNetwork m1, MultiLayerNetwork m2) {
        double[] d = m1.params().toDoubleVector();
        double[] d2 = m2.params().toDoubleVector();
        for (int i = 0; i < d.length; i++) {
            if (Math.random() < .5)
                d[i] += d2[i];
        }
        return Nd4j.create(d);
    }

    static int getHighestOutput(INDArray output) {
        int x = 0;
        for (int i = 0; i < 7; i++) {
            if (output.getDouble(i) > output.getDouble(x))
                x = i;
        }
        return x;
    }

    static float[] rowsToInput(byte[][] rows) {
        float[] f = new float[7 * 6];
        for (int i = 0; i < 6; i++) {
            for (int j = 0; j < 7; j++) {
                // f[j + i * 7] = rows[j][i] / 2f;
                f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);
            }
        }
        return f;
    }

    public void saveWeights() {
        log.info("Saving model");
        for (int i = 0; i < models.length; i++) {
            File resourcesDirectory = new File("src/resources/model" + i);
            try {
                models[i].save(resourcesDirectory, true);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    public void loadWeights() {
        if (new File("src/resources/model0").exists()) {
            for (int i = 0; i < models.length; i++) {
                File resourcesDirectory = new File("src/resources/model" + i);
                try {

                    models[i] = MultiLayerNetwork.load(resourcesDirectory, true);
                } catch (IOException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }
        }
        System.out.println("col: " + models[0].params().shapeInfoToString());
    }

    public VGFrame getFrame() {
        return frame;
    }

}

VGFrame

import java.awt.Color;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import javax.swing.BorderFactory;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.JTextField;

public class VGFrame extends JFrame {
    JTextField iterations;
    /**
     * 
     */
    private static final long serialVersionUID = 1L;

    public VGFrame() {
        super("Vier Gewinnt");
        this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        this.setSize(1300, 800);
        this.setVisible(true);
        JPanel panelGame = new JPanel();
        panelGame.setBorder(BorderFactory.createLineBorder(Color.black, 2));
        this.add(panelGame);

        var handler = new Handler();
        var menuHandler = new MenuHandler();

        JButton b1 = new JButton("1");
        JButton b2 = new JButton("2");
        JButton b3 = new JButton("3");
        JButton b4 = new JButton("4");
        JButton b5 = new JButton("5");
        JButton b6 = new JButton("6");
        JButton b7 = new JButton("7");
        b1.addActionListener(handler);
        b2.addActionListener(handler);
        b3.addActionListener(handler);
        b4.addActionListener(handler);
        b5.addActionListener(handler);
        b6.addActionListener(handler);
        b7.addActionListener(handler);
        panelGame.add(b1);
        panelGame.add(b2);
        panelGame.add(b3);
        panelGame.add(b4);
        panelGame.add(b5);
        panelGame.add(b6);
        panelGame.add(b7);

        JButton buttonTrain = new JButton("Train");
        JButton buttonNewGame = new JButton("New Game");
        JButton buttonSave = new JButton("Save Weights");
        JButton buttonLoad = new JButton("Load Weights");

        iterations = new JTextField("1000");

        buttonTrain.addActionListener(menuHandler);
        buttonNewGame.addActionListener(menuHandler);
        buttonSave.addActionListener(menuHandler);
        buttonLoad.addActionListener(menuHandler);
        iterations.addActionListener(menuHandler);

        panelGame.add(iterations);
        panelGame.add(buttonTrain);
        panelGame.add(buttonNewGame);
        panelGame.add(buttonSave);
        panelGame.add(buttonLoad);

        this.validate();
    }

    @Override
    public void paint(Graphics g) {
        super.paint(g);
        if (Main.current.mainGame.rows == null)
            return;
        var rows = Main.current.mainGame.rows;
        for (int i = 0; i < rows.length; i++) {
            for (int j = 0; j < rows[0].length; j++) {
                if (rows[i][j] == 0)
                    break;

                g.setColor((rows[i][j] == 1 ? Color.yellow : Color.red));
                g.fillOval(80 + 110 * i, 650 - 110 * j, 100, 100);
            }
        }
    }

    public void update() {
    }
}

class Handler implements ActionListener {

    @Override
    public void actionPerformed(ActionEvent event) {
        if (Main.current.mainGame.playersTurn)
            Main.current.addChip(Integer.parseInt(event.getActionCommand()) - 1, true);
    }
}

class MenuHandler implements ActionListener {

    @Override
    public void actionPerformed(ActionEvent event) {
        switch (event.getActionCommand()) {
        case "New Game":
            Main.current.newGame();
            break;
        case "Train":
            Main.current.startTraining(Integer.parseInt(Main.current.getFrame().iterations.getText()));
            break;
        case "Save Weights":
            Main.current.saveWeights();
            break;
        case "Load Weights":
            Main.current.loadWeights();
            break;
        }

    }
}

Game

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class Game {

    int turnNumber = 0;
    byte[][] rows = new byte[7][6];
    boolean playersTurn = true;

    int gameState = 0; // 0:running, 1:Player1, 2:Player2, 3:Draw

    public boolean isRunning() {
        return this.gameState == 0;
    }

    public void addChip(int x, boolean player1) {
        turnNumber++;
        byte b = nextRow(x);
        if (b == 6) {
            gameState = player1 ? 2 : 1;
            return;
        }
        rows[x][b] = (byte) (player1 ? 1 : 2);
        gameState = checkWinner(x, b);
    }

    private byte nextRow(int x) {
        for (byte i = 0; i < rows[x].length; i++) {
            if (rows[x][i] == 0)
                return i;
        }
        return 6;
    }

    // 0 continue, 1 Player won, 2 ai won, 3 Draw
    private int checkWinner(int x, int y) {
        int color = rows[x][y];
        // Vertikal
        if (getCount(x, y, 1, 0) + getCount(x, y, -1, 0) >= 3)
            return rows[x][y];

        // Horizontal
        if (getCount(x, y, 0, 1) + getCount(x, y, 0, -1) >= 3)
            return rows[x][y];

        // Diagonal1
        if (getCount(x, y, 1, 1) + getCount(x, y, -1, -1) >= 3)
            return rows[x][y];
        // Diagonal2
        if (getCount(x, y, -1, 1) + getCount(x, y, 1, -1) >= 3)
            return rows[x][y];
        
        for (byte[] bs : rows) {
            for (byte s : bs) {
                if (s == 0)
                    return 0;
            }
        }
        return 3; // Draw
    }

    private int getCount(int x, int y, int dirX, int dirY) {
        int color = rows[x][y];
        int count = 0;
        while (true) {
            x += dirX;
            y += dirY;
            if (x < 0 | x > 6 | y < 0 | y > 5)
                break;
            if (color != rows[x][y])
                break;
            count++;
        }
        return count;
    }

    public void playFullGame(MultiLayerNetwork m1, MultiLayerNetwork m2) {
        boolean player1 = true;
        while (this.gameState == 0) {
            float[] f = Main.rowsToInput(this.rows);
            INDArray input = Nd4j.create(f);
            this.addChip(Main.getHighestOutput(player1 ? m1.output(input) : m2.output(input)), player1);
            player1 = !player1;
        }
    }
}
Idiom answered 5/2, 2021 at 1:53 Comment(0)

With a quick look, and based on the analysis of your multiplier variants, it seems like the NaN is produced by an arithmetic underflow, caused by your gradients being too small (too close to absolute 0).

This is the most suspicious part of the code:

 f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);

If rows[j][i] == 1, then 0f is stored. I don't know how this is managed by the neural network (or even Java), but mathematically speaking, a finite-sized float cannot include zero.

Even if your code altered the 0f with some extra salt, the resulting array values would still risk getting too close to zero. Due to the limited precision used when representing real numbers, values very close to zero cannot be represented, hence the NaN.

These values have a very friendly name: subnormal numbers.

Any non-zero number with magnitude smaller than the smallest normal number is subnormal.


IEEE 754

As with IEEE 754-1985, the standard recommends 0 for signaling NaNs, 1 for quiet NaNs, so that a signaling NaN can be quieted by changing only this bit to 1, while the reverse could yield the encoding of an infinity.

The text above is important here: according to the standard, you are actually specifying a NaN with any 0f value stored.


Even if the name is misleading, Float.MIN_VALUE is a positive value, higher than 0.


The real minimum float value is, in fact: -Float.MAX_VALUE.

Is floating point math subnormal?
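
As a quick sanity check, here is a minimal JDK-only sketch (no deeplearning4j involved) that prints the limits mentioned above and shows what underflows to zero versus what actually yields NaN:

public class FloatLimits {
    public static void main(String[] args) {
        // The named constants from the discussion above
        System.out.println("Float.MIN_VALUE  = " + Float.MIN_VALUE);    // 1.4E-45, smallest positive (subnormal) float
        System.out.println("Float.MIN_NORMAL = " + Float.MIN_NORMAL);   // 1.17549435E-38, smallest positive normal float
        System.out.println("-Float.MAX_VALUE = " + (-Float.MAX_VALUE)); // the true minimum float value

        // Going below the subnormal range underflows to 0.0 ...
        System.out.println(Float.MIN_VALUE / 2f);  // 0.0
        // ... and NaN shows up once such zeros are operated on, e.g. 0/0:
        System.out.println(0f / 0f);               // NaN
    }
}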


Normalizing the gradients

If you confirm the issue is caused only by the 0f values, you could just swap them for other values that represent something similar: Float.MIN_VALUE, Float.MIN_NORMAL, and so on. Something like the lines below, also in any other part of the code where this scenario could happen. Take these just as examples, and play with these ranges:

rows[j][i] == 1 ? Float.MIN_VALUE : 1f;

rows[j][i] == 1 ?  Float.MIN_NORMAL : Float.MAX_VALUE/2;

rows[j][i] == 1 ? -Float.MAX_VALUE/2 : Float.MAX_VALUE/2;

Even so, this could also lead to a NaN, depending on how these values are altered. If so, you should normalize the values. You could try applying a GradientNormalizer for this. In your network initialization, something like this should be defined for each layer (or for those that are problematic):

new NeuralNetConfiguration
  .Builder()
  .weightInit(WeightInit.XAVIER)
  (...)
  .layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
        .weightInit(WeightInit.XAVIER)
        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) //this   
        .build())
  
  (...)

There are different normalizers, so choose which one fits your schema best, and which layers should include one. The options are:

GradientNormalization

  • RenormalizeL2PerLayer

    Rescale gradients by dividing by the L2 norm of all gradients for the layer.

  • RenormalizeL2PerParamType

    Rescale gradients by dividing by the L2 norm of the gradients, separately for each type of parameter within the layer. This differs from RenormalizeL2PerLayer in that here, each parameter type (weight, bias etc) is normalized separately. For example, in a MLP/FeedForward network (where G is the gradient vector), the output is as follows:

    GOut_weight = G_weight / l2(G_weight)
    GOut_bias = G_bias / l2(G_bias)

  • ClipElementWiseAbsoluteValue

    Clip the gradients on a per-element basis. For each gradient g, set g <- sign(g) * min(maxAllowedValue, |g|); i.e., if a parameter gradient has absolute value greater than the threshold, truncate it. For example, if threshold = 5, then values in range -5<g<5 are unmodified; values <-5 are set to -5; values >5 are set to 5.

  • ClipL2PerLayer

    Conditional renormalization. Somewhat similar to RenormalizeL2PerLayer, this strategy scales the gradients if and only if the L2 norm of the gradients (for the entire layer) exceeds a specified threshold. Specifically, if G is the gradient vector for the layer, then:

    GOut = G if l2Norm(G) < threshold (i.e., no change)
    GOut = threshold * G / l2Norm(G) otherwise

  • ClipL2PerParamType

    Conditional renormalization. Very similar to ClipL2PerLayer, however instead of clipping per layer, do clipping on each parameter type separately. For example in a recurrent neural network, input weight gradients, recurrent weight gradients and bias gradient are all clipped separately.


Here you can find a complete example of the application of these GradientNormalizers.
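
If the linked example isn't at hand, here is a self-contained sketch of the same idea, using the question's layer sizes and setting the normalizer once as a layer-wise default rather than per layer; ClipElementWiseAbsoluteValue and the 1.0 threshold are just placeholder choices to tune:

import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

public class ClippedConfigSketch {
    public static MultiLayerConfiguration build(int seed) {
        return new NeuralNetConfiguration.Builder()
                .seed(seed)
                .weightInit(WeightInit.XAVIER)
                // set once here; becomes the default for every layer below
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0)
                .list()
                .layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU).build())
                .layer(new DenseLayer.Builder().nIn(30).nOut(15).activation(Activation.RELU).build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(15).nOut(7).activation(Activation.SOFTMAX).build())
                .build();
    }
}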

Duffer answered 5/2, 2021 at 2:55 Comment(6)
Thank you for your detailed explanation. I believed the gradient stuff only matters in backpropagation, is this not correct? I changed my input to Float.MIN_VALUE and added the gradient normalization to all layers, but the problem still occurs. – Idiom
@Idiom hi! I did work on neural networks, but that was many years ago (phew!!). This could also happen if some array fields store absolute zero values. Correct me if I'm wrong, but your gradients don't get too big, right? They're oscillating between 0 and some other small values. I tell you what, I won't let this be, I'm intrigued to make it work. So please don't hesitate to ask for more info. Are there any divisions here? What is your expected result for the layer's values? – Duffer
Also, do you have access to some stack trace, in order to know where things get complicated? We're going to make this work, no matter what! You have my word, a Basque's word. I will kill half of humanity to achieve this if needed. – Duffer
Do the 0f and 1f values represent opposite values? I mean, could they be replaced, for example, by -Float.MAX_VALUE/2 and Float.MAX_VALUE/2? Or even simpler, by -10f and 10f? Even so, I think the first thing to check is that no 0 float is anywhere near. – Duffer
A workaround would also be changing the floats to doubles, as they don't have this kind of precision problem and don't usually overflow for such values. You'll lose precision, but as a test, it could help identify the main issue. Floating point is broken, mate. – Duffer
@Idiom as an update, take a look at the IEEE spec: "As with IEEE 754-1985, the standard recommends 0 for signaling NaNs, and 1 for quiet NaNs" -- this may be problematic indeed. – Duffer

I think I finally figured it out. I was trying to visualize the network using deeplearning4j-ui, but got some incompatible-version errors. After changing versions I got a new error, stating that the network's input expects a 2D array, and I found on the internet that this is expected across all versions.

So I changed

float[] f = new float[7 * 6];
Nd4j.create(f);

to

float[][] f = new float[1][7 * 6];
Nd4j.createFromArray(f);

And the NaN values finally disappeared. @aran, so I guess assuming incorrect inputs was definitely the right direction. Thank you so much for your help :)
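
For completeness, a sketch of how that change slots into the existing playFullGame loop (class and method names are the ones from the question; only the input shape changes):

float[][] f = new float[1][];
f[0] = Main.rowsToInput(this.rows);        // the 42 board values as a single row
INDArray input = Nd4j.createFromArray(f);  // shape [1, 42] instead of [42]
this.addChip(Main.getHighestOutput(player1 ? m1.output(input) : m2.output(input)), player1);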

Idiom answered 8/2, 2021 at 15:11 Comment(1)
Got some related info.. will update during these days. – Duffer
