Pre-train a brain.js model

My Question


I just started learning brain.js and built a model that predicts a category from the input text.

Each time I run the script, the model is trained again. The more iterations I use, the longer training takes, but the more accurate the model becomes.

Is there any way to pre-train the model so that users don't have to wait for the output?

An example would really help.

My Code


// data.json //

[
  {
    "text": "my unit test failed",
    "category": "software"
  },
  {
    "text": "my driver is working",
    "category": "hardware"
  }
]

const brain = require('brain.js');
const data = require('./data.json'); // training data loaded from data.json

const network = new brain.recurrent.LSTM();

const trainingData = data.map(item => ({
  input: item.text,
  output: item.category
}));

network.train(trainingData, {
  log: (error) => console.log(error),
  iterations: 1000
});

console.log(network.run('buy me a driver')); // output: hardware

Puddle answered 15/12, 2020 at 7:12

You can split the script into two. The first trains the network on the data and saves its state to a JSON file using network.toJSON().

The second loads the network state from that JSON file using network.fromJSON() and then runs it on new input.

train-network.js

const brain = require('brain.js');
const data = require('./data.json');    
const fs = require("fs");

const network = new brain.recurrent.LSTM();

const trainingData = data.map(item => ({
  input: item.text,
  output: item.category
}));

network.train(trainingData, {
  log: (error) => console.log(error),
  iterations: 1000
});

// Save network state to JSON file.
const networkState = network.toJSON();
fs.writeFileSync("network_state.json", JSON.stringify(networkState), "utf-8");

load-network.js

const brain = require('brain.js');
const fs = require("fs");

const network = new brain.recurrent.LSTM();

// Load the trained network data from JSON file.
const networkState = JSON.parse(fs.readFileSync("network_state.json", "utf-8").toString());
network.fromJSON(networkState);

console.log(network.run('buy me a driver')); 
Crooked answered 15/12, 2020 at 8:44
Comment by Journeyman: To resolve the TypeScript error "Argument of type 'Buffer' is not assignable to parameter of type 'string'" at JSON.parse(), use readFileSync().toString().
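The error in the comment is a TypeScript type error: Node's fs.readFileSync returns a Buffer when called without an encoding and a string when an encoding such as "utf-8" is passed, while JSON.parse is typed to accept a string. A minimal sketch of both forms (the temporary file name is arbitrary):

```javascript
const fs = require("fs");
const os = require("os");
const path = require("path");

// Write a small JSON file to read back (stand-in for network_state.json).
const statePath = path.join(os.tmpdir(), "network_state_encoding_demo.json");
fs.writeFileSync(statePath, JSON.stringify({ iterations: 1000 }), "utf-8");

// No encoding: readFileSync returns a Buffer (typed as Buffer in TypeScript).
const raw = fs.readFileSync(statePath);

// With an encoding: it returns a string, the type JSON.parse expects.
const text = fs.readFileSync(statePath, "utf-8");

console.log(Buffer.isBuffer(raw), typeof text); // true string

// Both routes yield the same parsed object.
const fromBuffer = JSON.parse(raw.toString("utf-8"));
const fromString = JSON.parse(text);
console.log(fromBuffer.iterations, fromString.iterations); // 1000 1000
```

Calling .toString() on the Buffer, as the comment suggests, and passing an encoding up front lead to the same result.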

You can also save the trained network as a function with net.toFunction() and use it elsewhere, as described here.
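A sketch of the "use it elsewhere" part, assuming `net.toFunction()` returns a self-contained JavaScript function (the brain.js README shows printing `run.toString()` so the function can be pasted into code that does not import brain.js). To keep the example runnable without brain.js, `run` here is a hypothetical hand-written stand-in, and `exportStandalone` is a hypothetical helper that writes the function's source into a CommonJS module.

```javascript
const fs = require("fs");
const os = require("os");
const path = require("path");

// Hypothetical stand-in for `const run = net.toFunction();`. A function
// exported from a trained network would carry its weights inline instead
// of this hard-coded rule.
function run(text) {
  return /driver/i.test(text) ? "hardware" : "software";
}

// Write the function's source text into its own CommonJS module.
function exportStandalone(fn, filePath) {
  fs.writeFileSync(filePath, `module.exports = ${fn.toString()};\n`, "utf-8");
}

const modulePath = path.join(os.tmpdir(), "run-network-sketch.js");
exportStandalone(run, modulePath);

// Elsewhere: load the generated module without importing brain.js or retraining.
const classify = require(modulePath);
console.log(classify("buy me a driver")); // hardware
```

Function.prototype.toString captures only the function's own source, so this approach works only if the exported function does not depend on variables from an enclosing scope.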

Journeyman answered 1/1, 2023 at 15:33
