Trainer
The Trainer makes it easier to train any set to any network, no matter its architecture. To create a trainer you just have to provide a Network to train.
var trainer = new Trainer(myNetwork);
The trainer also contains built-in tasks to test the performance of your network.
The train method allows you to train any training set to a Network. The training set must be an Array of objects, each with an input and an output property. For example, this is how you train an XOR to a network using a trainer:
var myNetwork = new Architect.Perceptron(2, 2, 1)
var trainer = new Trainer(myNetwork)
var trainingSet = [
{
input: [0,0],
output: [0]
},
{
input: [0,1],
output: [1]
},
{
input: [1,0],
output: [1]
},
{
input: [1,1],
output: [0]
},
]
trainer.train(trainingSet);
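Since the training set has to follow the input/output shape described above, it can help to check it before training. The following is a minimal sketch, not part of Synaptic: validateTrainingSet is a hypothetical helper, and the sizes 2 and 1 are assumed to match the Perceptron(2, 2, 1) used in the example.

```javascript
// Hypothetical helper (not a Synaptic API): verifies that every sample has
// an `input` and an `output` array of the expected lengths.
function validateTrainingSet(trainingSet, inputSize, outputSize) {
  if (!Array.isArray(trainingSet) || trainingSet.length === 0) {
    throw new TypeError('training set must be a non-empty Array');
  }
  trainingSet.forEach(function (sample, index) {
    if (!sample || !Array.isArray(sample.input) || sample.input.length !== inputSize) {
      throw new RangeError('sample ' + index + ': input must have ' + inputSize + ' values');
    }
    if (!Array.isArray(sample.output) || sample.output.length !== outputSize) {
      throw new RangeError('sample ' + index + ': output must have ' + outputSize + ' values');
    }
  });
  return true;
}

var xorSet = [
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [1] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [0] }
];

// Throws if the set does not fit a network with 2 inputs and 1 output.
validateTrainingSet(xorSet, 2, 1);
```

Running a check like this before calling train can make shape mistakes fail early with a readable message.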
You can also set different options for the training in an object as a second parameter, like:
trainer.train(trainingSet,{
rate: .1,
iterations: 20000,
error: .005,
shuffle: true,
log: 1000,
cost: Trainer.cost.CROSS_ENTROPY
});
- rate: learning rate to train the network. It can be a static rate (just a number), dynamic (an array of numbers, which will transition from one to the next according to the number of iterations), or a callback function: (iterations, error) => rate.
- iterations: maximum number of iterations.
- error: minimum error.
- shuffle: if true, the training set is shuffled after every iteration. This is useful for training data sequences whose order is not meaningful to networks with context memory, like LSTMs.
- cost: you can set what cost function to use for the training. There are three built-in cost functions (Trainer.cost.CROSS_ENTROPY, Trainer.cost.MSE and Trainer.cost.BINARY) to choose from: cross-entropy, mean squared error, or binary error. You can also use your own cost function(targetValues, outputValues).
- log: this commands the trainer to console.log the error and iterations every X number of iterations.
- schedule: you can create custom scheduled tasks that will be executed every X number of iterations. It can be used to create custom logs, or to compute analytics based on the data passed to the task (the data object includes error, iterations and the current learning rate). If the returned value of the task is true, the training will be aborted. This can be used to create special conditions to stop the training (e.g. if the error starts to increase).
schedule: {
every: 500, // repeat this task every 500 iterations
do: function(data) {
// custom log
console.log("error", data.error, "iterations", data.iterations, "rate", data.rate);
if (someCondition)
return true; // abort/stop training
}
}
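The stop condition mentioned above (aborting when the error starts to increase) could be sketched as follows. stopWhenErrorRises and its patience parameter are hypothetical names introduced for illustration; the sketch only assumes that the task receives a data object with an error property and that returning true aborts the training.

```javascript
// Hypothetical helper: builds a schedule object whose task returns true
// (abort) once the error has risen for `patience` consecutive checks.
function stopWhenErrorRises(every, patience) {
  var previousError = Infinity;
  var rises = 0;
  return {
    every: every, // run the task every `every` iterations
    do: function (data) {
      // Count consecutive increases; any decrease or tie resets the counter.
      rises = data.error > previousError ? rises + 1 : 0;
      previousError = data.error;
      if (rises >= patience) {
        return true; // abort/stop training
      }
    }
  };
}

// Possible usage: { schedule: stopWhenErrorRises(500, 2) } in the options.
var schedule = stopWhenErrorRises(500, 2);
```

Keeping the counters inside a closure means each call to the helper yields an independent task, so the same helper can be reused across training runs.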
When the training is done this method returns an object with the error, the iterations, and the elapsed time of the training.
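Among the options, rate can be a callback of the form (iterations, error) => rate. One possible sketch of such a callback is an exponential decay with a lower bound. decayingRate and its parameters are hypothetical and not part of the library, and halving the rate every 1000 iterations is an arbitrary choice for illustration.

```javascript
// Hypothetical factory: returns a callback with the
// (iterations, error) => rate shape accepted by the rate option.
function decayingRate(initialRate, decay, minRate) {
  return function (iterations, error) {
    // `error` is unused in this sketch; a callback could also react to it.
    var rate = initialRate * Math.pow(decay, iterations / 1000);
    return Math.max(rate, minRate); // never drop below the floor
  };
}

// Starts at 0.3, is multiplied by 0.5 per 1000 iterations, floors at 0.01.
var rateCallback = decayingRate(0.3, 0.5, 0.01);

// Possible usage:
// trainer.train(trainingSet, { rate: rateCallback, iterations: 20000 });
```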
The trainAsync method works the same way as train, but it uses a WebWorker so the training doesn't block the user interface (a really long training using the train method might freeze the UI in the browser, but that doesn't happen with trainAsync). This method doesn't work in node.js, and it might not work in every browser (the browser has to support Blob and WebWorkers).
var trainer = new Trainer(myNetwork);
trainer.trainAsync(set, options)
.then(results => console.log('done!', results));
It has the same signature and supports the same options as train, but instead of returning the training results it returns a Promise that resolves to the training results.
This is an example of how to train an XOR using the trainAsync method:
var myNetwork = new Architect.Perceptron(2, 2, 1)
var trainer = new Trainer(myNetwork)
var trainingSet = [
{
input: [0,0],
output: [0]
},
{
input: [0,1],
output: [1]
},
{
input: [1,0],
output: [1]
},
{
input: [1,1],
output: [0]
},
]
trainer.trainAsync(trainingSet)
.then(results => console.log('done!', results))
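Because trainAsync depends on Blob and WebWorker support, code that has to run in several environments may want to fall back to train. The sketch below is one way to do that under stated assumptions: supportsTrainAsync and trainWithFallback are hypothetical helpers, and the check only looks for Worker and Blob on the global object (or on an env object passed in), which may not cover every case.

```javascript
// Hypothetical helper: reports whether the environment exposes the two
// APIs that trainAsync is described as requiring.
function supportsTrainAsync(env) {
  env = env || globalThis;
  return typeof env.Worker === 'function' && typeof env.Blob === 'function';
}

// Hypothetical wrapper: always returns a Promise of the training results,
// using trainAsync when available and the synchronous train otherwise.
function trainWithFallback(trainer, trainingSet, options, env) {
  if (supportsTrainAsync(env)) {
    return trainer.trainAsync(trainingSet, options);
  }
  return new Promise(function (resolve) {
    // Errors thrown by train become a rejected Promise.
    resolve(trainer.train(trainingSet, options));
  });
}

// Possible usage:
// trainWithFallback(trainer, trainingSet, { iterations: 20000 })
//   .then(results => console.log('done!', results));
```

Returning a Promise in both branches lets the calling code stay the same regardless of which method ran.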
The test method accepts the same arguments as train(dataSet, options). It iterates over the dataSet, activating the network, and returns the elapsed time and the error (by default the MSE, but you can specify the cost function in the options, the same way as in train()).
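To illustrate what such an evaluation pass involves, here is a minimal standalone sketch. It is not the library's implementation: evaluate, meanSquaredError, meanAbsoluteError and perfectXor are hypothetical names, and activate stands in for a network as any function that maps an input array to an output array. The cost functions follow the (targetValues, outputValues) shape mentioned for custom cost functions.

```javascript
// Mean squared error in the (targetValues, outputValues) shape.
function meanSquaredError(targetValues, outputValues) {
  var sum = 0;
  for (var i = 0; i < outputValues.length; i++) {
    sum += Math.pow(targetValues[i] - outputValues[i], 2);
  }
  return sum / outputValues.length;
}

// An example of a custom cost function with the same shape.
function meanAbsoluteError(targetValues, outputValues) {
  var sum = 0;
  for (var i = 0; i < outputValues.length; i++) {
    sum += Math.abs(targetValues[i] - outputValues[i]);
  }
  return sum / outputValues.length;
}

// Hypothetical evaluation loop: activates on every sample, averages the
// cost over the data set, and reports the elapsed time in milliseconds.
function evaluate(dataSet, activate, cost) {
  cost = cost || meanSquaredError;
  var start = Date.now();
  var total = 0;
  dataSet.forEach(function (sample) {
    total += cost(sample.output, activate(sample.input));
  });
  return { error: total / dataSet.length, time: Date.now() - start };
}

var xorSamples = [
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [1] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [0] }
];

// A hand-written XOR stands in for a trained network here.
function perfectXor(input) {
  return [input[0] ^ input[1]];
}

var report = evaluate(xorSamples, perfectXor);
```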
The XOR method trains an XOR to the network. It is useful when you are experimenting with different architectures and want to test and compare their performance:
var trainer = new Trainer(myNetwork);
trainer.XOR(); // {error: 0.004999821588193305, iterations: 21333, time: 111}
This method trains the network to complete a Discrete Sequence Recall, which is a task for testing context memory in neural networks.
trainer.DSR({
targets: [2,4],
distractors: [3,5],
prompts: [0,1],
length: 10
});
This method trains the network to pass an Embedded Reber Grammar test.
trainer.ERG();
This test challenges the network to complete a timing task.