I am trying to write my first neural network to play the game connect four. Im using Java and deeplearning4j. I tried to implement a genetic algorithm, but when i train the network for a while, the outputs of the network jump to NaN and I am unable to tell where I messed up so badly for this to happen.. I will post all 3 classes below, where Game is the game logic and rules, VGFrame the UI and Main all the nn stuff.
I have a pool of 35 neural networks and each iteration i let the best 5 live and breed and randomize the newly created ones a little. To evaluate the networks I let them battle each other and give points to the winner and points for loosing later. Since I penalize putting a stone into a column thats already full I expected the neural networks at least to be able to play the game by the rules after a while but they cant do this. I googled the NaN problem and it seems to be an expoding gradient problem, but from my understanding this shouldn't occur in a genetic algorithm? Any ideas where I could look for the error or whats generally wrong with my implementation?
Main
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
public class Main {
final int numRows = 7;
final int numColums = 6;
final int randSeed = 123;
MultiLayerNetwork[] models;
static Random random = new Random();
private static final Logger log = LoggerFactory.getLogger(Main.class);
final float learningRate = .8f;
int batchSize = 64; // Test batch size
int nEpochs = 1; // Number of training epochs
// --
public static Main current;
Game mainGame = new Game();
public static void main(String[] args) {
current = new Main();
current.frame = new VGFrame();
current.loadWeights();
}
private VGFrame frame;
private final double mutationChance = .05;
public Main() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
.activation(Activation.RELU).seed(randSeed)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Nesterovs(0.1, 0.9))
.list()
.layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build())
.layer(new DenseLayer.Builder().nIn(30).nOut(15).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build())
.layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(15).nOut(7)
.activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build())
.build();
models = new MultiLayerNetwork[35];
for (int i = 0; i < models.length; i++) {
models[i] = new MultiLayerNetwork(conf);
models[i].init();
}
}
public void addChip(int i, boolean b) {
if (mainGame.gameState == 0)
mainGame.addChip(i, b);
if (mainGame.gameState == 0) {
float[] f = Main.rowsToInput(mainGame.rows);
INDArray input = Nd4j.create(f);
INDArray output = models[0].output(input);
for (int i1 = 0; i1 < 7; i1++) {
System.out.println(i1 + ": " + output.getDouble(i1));
}
System.out.println("----------------");
mainGame.addChip(Main.getHighestOutput(output), false);
}
getFrame().paint(getFrame().getGraphics());
}
public void newGame() {
mainGame = new Game();
getFrame().paint(getFrame().getGraphics());
}
public void startTraining(int iterations) {
// --------------------------
for (int gameNumber = 0; gameNumber < iterations; gameNumber++) {
System.out.println("Iteration " + gameNumber + " of " + iterations);
float[] evaluation = new float[models.length];
for (int i = 0; i < models.length; i++) {
for (int j = 0; j < models.length; j++) {
if (i != j) {
Game g = new Game();
g.playFullGame(models[i], models[j]);
if (g.gameState == 1) {
evaluation[i] += 45;
evaluation[j] += g.turnNumber;
}
if (g.gameState == 2) {
evaluation[j] += 45;
evaluation[i] += g.turnNumber;
}
}
}
}
float[] evaluationSorted = evaluation.clone();
Arrays.sort(evaluationSorted);
// keep the best 4
int n1 = 0, n2 = 0, n3 = 0, n4 = 0, n5 = 0;
for (int i = 0; i < evaluation.length; i++) {
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 1])
n1 = i;
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 2])
n2 = i;
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 3])
n3 = i;
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 4])
n4 = i;
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 5])
n5 = i;
}
models[0] = models[n1];
models[1] = models[n2];
models[2] = models[n3];
models[3] = models[n4];
models[4] = models[n5];
for (int i = 3; i < evaluationSorted.length; i++) {
// random parent/keep w8ts
double r = Math.random();
if (r > .3) {
models[i] = models[random.nextInt(3)].clone();
} else if (r > .1) {
models[i].setParams(breed(models[random.nextInt(3)], models[random.nextInt(3)]));
}
// Mutate
INDArray params = models[i].params();
models[i].setParams(mutate(params));
}
}
}
private INDArray mutate(INDArray params) {
double[] d = params.toDoubleVector();
for (int i = 0; i < d.length; i++) {
if (Math.random() < mutationChance)
d[i] += (Math.random() - .5) * learningRate;
}
return Nd4j.create(d);
}
private INDArray breed(MultiLayerNetwork m1, MultiLayerNetwork m2) {
double[] d = m1.params().toDoubleVector();
double[] d2 = m2.params().toDoubleVector();
for (int i = 0; i < d.length; i++) {
if (Math.random() < .5)
d[i] += d2[i];
}
return Nd4j.create(d);
}
static int getHighestOutput(INDArray output) {
int x = 0;
for (int i = 0; i < 7; i++) {
if (output.getDouble(i) > output.getDouble(x))
x = i;
}
return x;
}
static float[] rowsToInput(byte[][] rows) {
float[] f = new float[7 * 6];
for (int i = 0; i < 6; i++) {
for (int j = 0; j < 7; j++) {
// f[j + i * 7] = rows[j][i] / 2f;
f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);
}
}
return f;
}
public void saveWeights() {
log.info("Saving model");
for (int i = 0; i < models.length; i++) {
File resourcesDirectory = new File("src/resources/model" + i);
try {
models[i].save(resourcesDirectory, true);
} catch (IOException e) {
e.printStackTrace();
}
}
}
public void loadWeights() {
if (new File("src/resources/model0").exists()) {
for (int i = 0; i < models.length; i++) {
File resourcesDirectory = new File("src/resources/model" + i);
try {
models[i] = MultiLayerNetwork.load(resourcesDirectory, true);
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
System.out.println("col: " + models[0].params().shapeInfoToString());
}
public VGFrame getFrame() {
return frame;
}
}
VGFrame
import java.awt.Color;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import javax.swing.BorderFactory;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.JTextField;
public class VGFrame extends JFrame {
JTextField iterations;
/**
*
*/
private static final long serialVersionUID = 1L;
public VGFrame() {
super("Vier Gewinnt");
this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
this.setSize(1300, 800);
this.setVisible(true);
JPanel panelGame = new JPanel();
panelGame.setBorder(BorderFactory.createLineBorder(Color.black, 2));
this.add(panelGame);
var handler = new Handler();
var menuHandler = new MenuHandler();
JButton b1 = new JButton("1");
JButton b2 = new JButton("2");
JButton b3 = new JButton("3");
JButton b4 = new JButton("4");
JButton b5 = new JButton("5");
JButton b6 = new JButton("6");
JButton b7 = new JButton("7");
b1.addActionListener(handler);
b2.addActionListener(handler);
b3.addActionListener(handler);
b4.addActionListener(handler);
b5.addActionListener(handler);
b6.addActionListener(handler);
b7.addActionListener(handler);
panelGame.add(b1);
panelGame.add(b2);
panelGame.add(b3);
panelGame.add(b4);
panelGame.add(b5);
panelGame.add(b6);
panelGame.add(b7);
JButton buttonTrain = new JButton("Train");
JButton buttonNewGame = new JButton("New Game");
JButton buttonSave = new JButton("Save Weights");
JButton buttonLoad = new JButton("Load Weights");
iterations = new JTextField("1000");
buttonTrain.addActionListener(menuHandler);
buttonNewGame.addActionListener(menuHandler);
buttonSave.addActionListener(menuHandler);
buttonLoad.addActionListener(menuHandler);
iterations.addActionListener(menuHandler);
panelGame.add(iterations);
panelGame.add(buttonTrain);
panelGame.add(buttonNewGame);
panelGame.add(buttonSave);
panelGame.add(buttonLoad);
this.validate();
}
@Override
public void paint(Graphics g) {
super.paint(g);
if (Main.current.mainGame.rows == null)
return;
var rows = Main.current.mainGame.rows;
for (int i = 0; i < rows.length; i++) {
for (int j = 0; j < rows[0].length; j++) {
if (rows[i][j] == 0)
break;
g.setColor((rows[i][j] == 1 ? Color.yellow : Color.red));
g.fillOval(80 + 110 * i, 650 - 110 * j, 100, 100);
}
}
}
public void update() {
}
}
class Handler implements ActionListener {
@Override
public void actionPerformed(ActionEvent event) {
if (Main.current.mainGame.playersTurn)
Main.current.addChip(Integer.parseInt(event.getActionCommand()) - 1, true);
}
}
class MenuHandler implements ActionListener {
@Override
public void actionPerformed(ActionEvent event) {
switch (event.getActionCommand()) {
case "New Game":
Main.current.newGame();
break;
case "Train":
Main.current.startTraining(Integer.parseInt(Main.current.getFrame().iterations.getText()));
break;
case "Save Weights":
Main.current.saveWeights();
break;
case "Load Weights":
Main.current.loadWeights();
break;
}
}
}
Game
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class Game {
int turnNumber = 0;
byte[][] rows = new byte[7][6];
boolean playersTurn = true;
int gameState = 0; // 0:running, 1:Player1, 2:Player2, 3:Draw
public boolean isRunning() {
return this.gameState == 0;
}
public void addChip(int x, boolean player1) {
turnNumber++;
byte b = nextRow(x);
if (b == 6) {
gameState = player1 ? 2 : 1;
return;
}
rows[x][b] = (byte) (player1 ? 1 : 2);
gameState = checkWinner(x, b);
}
private byte nextRow(int x) {
for (byte i = 0; i < rows[x].length; i++) {
if (rows[x][i] == 0)
return i;
}
return 6;
}
// 0 continue, 1 Player won, 2 ai won, 3 Draw
private int checkWinner(int x, int y) {
int color = rows[x][y];
// Vertikal
if (getCount(x, y, 1, 0) + getCount(x, y, -1, 0) >= 3)
return rows[x][y];
// Horizontal
if (getCount(x, y, 0, 1) + getCount(x, y, 0, -1) >= 3)
return rows[x][y];
// Diagonal1
if (getCount(x, y, 1, 1) + getCount(x, y, -1, -1) >= 3)
return rows[x][y];
// Diagonal2
if (getCount(x, y, -1, 1) + getCount(x, y, 1, -1) >= 3)
return rows[x][y];
for (byte[] bs : rows) {
for (byte s : bs) {
if (s == 0)
return 0;
}
}
return 3; // Draw
}
private int getCount(int x, int y, int dirX, int dirY) {
int color = rows[x][y];
int count = 0;
while (true) {
x += dirX;
y += dirY;
if (x < 0 | x > 6 | y < 0 | y > 5)
break;
if (color != rows[x][y])
break;
count++;
}
return count;
}
public void playFullGame(MultiLayerNetwork m1, MultiLayerNetwork m2) {
boolean player1 = true;
while (this.gameState == 0) {
float[] f = Main.rowsToInput(this.rows);
INDArray input = Nd4j.create(f);
this.addChip(Main.getHighestOutput(player1 ? m1.output(input) : m2.output(input)), player1);
player1 = !player1;
}
}
}