I am a big fan of Keras. Fortunately, a similar framework is available when we need to implement a solution that runs on the JVM: DeepLearning4J (DL4J). In this blog post I give a quick introduction to some of the key concepts needed to get started, using the Iris dataset to demonstrate how to classify data with a neural network.

This demo has been implemented in Scala using Jupyter with the BeakerX kernel (http://beakerx.com/).

Maven Dependencies

To use this functionality we need the following dependencies:
– deeplearning4j-core
– nd4j, here nd4j-native-platform (which provides the underlying n-dimensional array data model and the native CPU backend)

%%classpath add mvn 
org.deeplearning4j:deeplearning4j-core:1.0.0-beta2
org.nd4j:nd4j-native-platform:1.0.0-beta2

Imports

I am importing all the classes which are used in the subsequent example:

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.datasets.datavec._
import org.deeplearning4j.eval._
import org.deeplearning4j.nn.conf._
import org.deeplearning4j.nn.conf.inputs.InputType
import org.deeplearning4j.nn.conf.layers._
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.deeplearning4j.optimize.listeners.ScoreIterationListener
import org.deeplearning4j.evaluation._

import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
import org.nd4j.linalg.learning.config.Nesterovs
import org.nd4j.linalg.lossfunctions.LossFunctions
import org.nd4j.linalg.dataset.api.preprocessor._
import org.nd4j.linalg.learning.config._

import org.datavec.api.records.reader.impl.csv._
import org.datavec.api.split._
import org.datavec.api.transform.schema.InferredSchema
import org.datavec.api.transform.TransformProcess
import org.datavec.api.records.reader.impl.transform.TransformProcessRecordReader
import org.datavec.api.transform.schema.Schema

import java.util.Arrays


import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.datasets.datavec._
import org.deeplearning4j.eval._
import org.deeplearning4j.nn.conf._
import org.deeplearning4j.nn.conf.inputs.InputType
import org.deeplearning4j.nn.conf.layers._
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.deeplearning4j.optimize.listeners.ScoreIterationListener
import org.deeplearning4j.evaluation._
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
import org.nd4j.linalg.learning.config.Nesterovs
import org.nd4j.linalg.lossfunctions.LossFunctions
import org.nd4j.linalg.dataset.api.preprocessor._
import org.nd4j.linalg.learning.config._
import or...

Data Input

We could use the IrisDataSetIterator provided by DL4J directly. However, I prefer to demonstrate how to build on data that is available on the Internet: the Iris Data Set.

If the data is in CSV format, we can use the CSVRecordReader to load it. Unfortunately, to turn the records into a DataSet all values must be numeric: strings or booleans (true, false) are not supported.

DataVec (part of DL4J) can be used to perform the necessary conversions. This is usually done with the help of Spark; in this small example the TransformProcessRecordReader applies the transformation locally while the records are read. DataVec needs a Schema: if the data is available as a local CSV file, the schema can be inferred with new InferredSchema(file).build().

var dataUrl = "https://gist.githubusercontent.com/netj/8836201/raw/6f9306ad21398ea43cba4f7d537619d0e07d5ae3/iris.csv"
var numLinesToSkip = 1
var delimiter = ','
var csvReader = new CSVRecordReader(numLinesToSkip, delimiter)


We cannot load this data directly because the variety column contains strings.

Schema inference does not work in this case because we are reading from a URL rather than from a local file, so we need to define the schema from scratch.

We then define a TransformProcess that converts the categorical variety column into integers (Setosa = 0, Versicolor = 1, Virginica = 2, following the order declared in the schema), and finally use the TransformProcessRecordReader to load the converted data.
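To make these two steps concrete, here is a minimal plain-Scala sketch, without DataVec, of what happens to one line of iris.csv: the four measurements are parsed as doubles and the quoted variety is replaced by its index in the category list declared in the schema. The object and method names are hypothetical and only illustrate the idea; in the notebook the real conversion is done by the TransformProcess.

```scala
object IrisRow {
  // Category order as declared in the schema: the position becomes the integer label
  val varieties: Vector[String] = Vector("Setosa", "Versicolor", "Virginica")

  // Map a category name to its index in the declared list (index-based, like categoricalToInteger)
  def categoryToInt(name: String): Int = {
    val idx = varieties.indexOf(name)
    require(idx >= 0, s"Unknown variety: $name")
    idx
  }

  // Parse a data line such as 5.1,3.5,1.4,.2,"Setosa" into 4 features and an integer label
  def parseLine(line: String): (Vector[Double], Int) = {
    val cols = line.split(',').map(_.trim.stripPrefix("\"").stripSuffix("\""))
    require(cols.length == 5, s"Expected 5 columns, got ${cols.length}")
    val features = cols.take(4).map(_.toDouble).toVector
    (features, categoryToInt(cols(4)))
  }
}
```

The header line of the file is not handled here; in the notebook it is skipped by passing numLinesToSkip = 1 to the CSVRecordReader.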

var schema = new Schema.Builder()
    .addColumnDouble("sepal.length")
    .addColumnDouble("sepal.width")
    .addColumnDouble("petal.length")
    .addColumnDouble("petal.width")
    .addColumnCategorical("variety", Arrays.asList("Setosa","Versicolor","Virginica"))
    .build();

var tp = new TransformProcess.Builder(schema)
    .categoricalToInteger("variety")                            
    .build()

var inputStream = new java.net.URL(dataUrl).openStream()
csvReader.initialize(new InputStreamInputSplit(inputStream))
var tpReader = new TransformProcessRecordReader(csvReader, tp)


Now we can load the data, shuffle it and split it into a training and test set.
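Shuffling and splitting can be sketched in plain Scala as follows. This is an illustration under assumptions (a fixed seed for reproducibility, and truncating fraction * size to get the training count), not DL4J's DataSet.shuffle or splitTestAndTrain code; the helper name is hypothetical. With 150 examples and a fraction of 0.90 it yields 135 training and 15 test examples, which matches the 3 + 5 + 7 = 15 test rows counted in the confusion matrix at the end.

```scala
import scala.util.Random

object Split {
  // Shuffle with a fixed seed, then use the first fractionTrain of the rows for training
  def shuffleAndSplit[A](data: Seq[A], fractionTrain: Double, seed: Long): (Seq[A], Seq[A]) = {
    require(fractionTrain > 0.0 && fractionTrain < 1.0, "fractionTrain must be in (0, 1)")
    val shuffled = new Random(seed).shuffle(data)
    val numTrain = (fractionTrain * data.size).toInt // 0.90 * 150 -> 135
    (shuffled.take(numTrain), shuffled.drop(numTrain))
  }
}
```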

var labelIndex = 4     // 5 values in each row of iris.csv: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
var numClasses = 3     // 3 classes (types of iris flowers) in the Iris data set. Classes have integer values 0, 1 or 2
var batchSize = 150    // Iris data set: 150 examples total. We load all of them into one DataSet (not recommended for large data sets)

var iterator = new RecordReaderDataSetIterator(tpReader, batchSize,labelIndex,numClasses)
var allData = iterator.next()
allData.shuffle()

var testAndTrain = allData.splitTestAndTrain(0.90)  // Use 90% of the data for training
var testData = testAndTrain.getTest()
var trainingData = testAndTrain.getTrain()

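Displaying trainingData shows its 135 examples: the four features under INPUT and the one-hot encoded labels under OUTPUT.

===========INPUT===================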
[[    4.6000,    3.2000,    1.4000,    0.2000], 
[    6.8000,    2.8000,    4.8000,    1.4000], 
[    6.6000,    3.0000,    4.4000,    1.4000], 
[    6.8000,    3.0000,    5.5000,    2.1000], 
[    5.9000,    3.0000,    4.2000,    1.5000], 
[    4.6000,    3.6000,    1.0000,    0.2000], 
[    7.2000,    3.0000,    5.8000,    1.6000], 
[    4.4000,    2.9000,    1.4000,    0.2000], 
[    5.0000,    3.4000,    1.5000,    0.2000], 
[    6.2000,    2.8000,    4.8000,    1.8000], 
[    5.5000,    2.5000,    4.0000,    1.3000], 
[    5.5000,    3.5000,    1.3000,    0.2000], 
[    5.7000,    4.4000,    1.5000,    0.4000], 
[    5.6000,    3.0000,    4.1000,    1.3000], 
[    5.7000,    2.8000,    4.5000,    1.3000], 
[    4.8000,    3.0000,    1.4000,    0.1000], 
[    6.3000,    2.8000,    5.1000,    1.5000], 
[    5.8000,    2.8000,    5.1000,    2.4000], 
[    5.6000,    2.5000,    3.9000,    1.1000], 
[    6.3000,    3.4000,    5.6000,    2.4000], 
[    5.8000,    2.7000,    5.1000,    1.9000], 
[    6.4000,    2.8000,    5.6000,    2.1000], 
[    5.4000,    3.9000,    1.7000,    0.4000], 
[    4.9000,    3.1000,    1.5000,    0.2000], 
[    7.3000,    2.9000,    6.3000,    1.8000], 
[    5.1000,    3.8000,    1.6000,    0.2000], 
[    6.3000,    2.9000,    5.6000,    1.8000], 
[    6.3000,    3.3000,    6.0000,    2.5000], 
[    6.4000,    3.2000,    5.3000,    2.3000], 
[    5.9000,    3.2000,    4.8000,    1.8000], 
[    7.7000,    3.8000,    6.7000,    2.2000], 
[    6.9000,    3.2000,    5.7000,    2.3000], 
[    7.2000,    3.6000,    6.1000,    2.5000], 
[    7.1000,    3.0000,    5.9000,    2.1000], 
[    5.1000,    3.3000,    1.7000,    0.5000], 
[    6.7000,    3.3000,    5.7000,    2.5000], 
[    4.8000,    3.4000,    1.6000,    0.2000], 
[    5.6000,    3.0000,    4.5000,    1.5000], 
[    5.7000,    2.5000,    5.0000,    2.0000], 
[    7.7000,    2.8000,    6.7000,    2.0000], 
[    6.4000,    3.1000,    5.5000,    1.8000], 
[    5.5000,    2.4000,    3.8000,    1.1000], 
[    6.1000,    2.9000,    4.7000,    1.4000], 
[    6.7000,    3.1000,    4.7000,    1.5000], 
[    6.3000,    2.3000,    4.4000,    1.3000], 
[    6.2000,    2.2000,    4.5000,    1.5000], 
[    5.0000,    3.0000,    1.6000,    0.2000], 
[    5.1000,    3.5000,    1.4000,    0.2000], 
[    6.1000,    2.8000,    4.0000,    1.3000], 
[    6.1000,    2.6000,    5.6000,    1.4000], 
[    6.0000,    2.7000,    5.1000,    1.6000], 
[    6.0000,    2.9000,    4.5000,    1.5000], 
[    6.3000,    2.7000,    4.9000,    1.8000], 
[    6.6000,    2.9000,    4.6000,    1.3000], 
[    5.7000,    3.0000,    4.2000,    1.2000], 
[    5.8000,    4.0000,    1.2000,    0.2000], 
[    6.8000,    3.2000,    5.9000,    2.3000], 
[    6.7000,    2.5000,    5.8000,    1.8000], 
[    5.5000,    2.3000,    4.0000,    1.3000], 
[    6.7000,    3.3000,    5.7000,    2.1000], 
[    7.9000,    3.8000,    6.4000,    2.0000], 
[    5.0000,    3.6000,    1.4000,    0.2000], 
[    5.7000,    2.6000,    3.5000,    1.0000], 
[    4.4000,    3.2000,    1.3000,    0.2000], 
[    6.9000,    3.1000,    5.4000,    2.1000], 
[    6.5000,    3.0000,    5.5000,    1.8000], 
[    4.9000,    3.6000,    1.4000,    0.1000], 
[    6.5000,    3.2000,    5.1000,    2.0000], 
[    4.8000,    3.0000,    1.4000,    0.3000], 
[    5.7000,    3.8000,    1.7000,    0.3000], 
[    4.7000,    3.2000,    1.6000,    0.2000], 
[    5.3000,    3.7000,    1.5000,    0.2000], 
[    6.1000,    2.8000,    4.7000,    1.2000], 
[    6.3000,    2.5000,    4.9000,    1.5000], 
[    5.6000,    2.8000,    4.9000,    2.0000], 
[    4.7000,    3.2000,    1.3000,    0.2000], 
[    4.9000,    3.0000,    1.4000,    0.2000], 
[    7.4000,    2.8000,    6.1000,    1.9000], 
[    6.4000,    3.2000,    4.5000,    1.5000], 
[    6.0000,    2.2000,    4.0000,    1.0000], 
[    4.5000,    2.3000,    1.3000,    0.3000], 
[    4.9000,    2.5000,    4.5000,    1.7000], 
[    5.9000,    3.0000,    5.1000,    1.8000], 
[    7.0000,    3.2000,    4.7000,    1.4000], 
[    6.5000,    3.0000,    5.2000,    2.0000], 
[    5.6000,    2.7000,    4.2000,    1.3000], 
[    5.0000,    3.5000,    1.6000,    0.6000], 
[    5.1000,    3.8000,    1.9000,    0.4000], 
[    4.3000,    3.0000,    1.1000,    0.1000], 
[    5.4000,    3.7000,    1.5000,    0.2000], 
[    6.4000,    2.7000,    5.3000,    1.9000], 
[    5.7000,    2.9000,    4.2000,    1.3000], 
[    5.0000,    2.3000,    3.3000,    1.0000], 
[    5.8000,    2.6000,    4.0000,    1.2000], 
[    6.0000,    3.4000,    4.5000,    1.6000], 
[    6.5000,    2.8000,    4.6000,    1.5000], 
[    6.1000,    3.0000,    4.6000,    1.4000], 
[    6.5000,    3.0000,    5.8000,    2.2000], 
[    7.2000,    3.2000,    6.0000,    1.8000], 
[    6.0000,    3.0000,    4.8000,    1.8000], 
[    5.2000,    3.5000,    1.5000,    0.2000], 
[    5.4000,    3.4000,    1.7000,    0.2000], 
[    5.5000,    4.2000,    1.4000,    0.2000], 
[    5.0000,    3.4000,    1.6000,    0.4000], 
[    6.2000,    2.9000,    4.3000,    1.3000], 
[    6.4000,    2.8000,    5.6000,    2.2000], 
[    5.5000,    2.4000,    3.7000,    1.0000], 
[    4.8000,    3.4000,    1.9000,    0.2000], 
[    4.9000,    2.4000,    3.3000,    1.0000], 
[    5.2000,    4.1000,    1.5000,    0.1000], 
[    5.8000,    2.7000,    4.1000,    1.0000], 
[    5.5000,    2.6000,    4.4000,    1.2000], 
[    4.6000,    3.1000,    1.5000,    0.2000], 
[    4.4000,    3.0000,    1.3000,    0.2000], 
[    5.1000,    3.5000,    1.4000,    0.3000], 
[    5.2000,    3.4000,    1.4000,    0.2000], 
[    6.7000,    3.1000,    4.4000,    1.4000], 
[    4.8000,    3.1000,    1.6000,    0.2000], 
[    6.1000,    3.0000,    4.9000,    1.8000], 
[    6.2000,    3.4000,    5.4000,    2.3000], 
[    5.4000,    3.4000,    1.5000,    0.4000], 
[    6.4000,    2.9000,    4.3000,    1.3000], 
[    5.1000,    3.4000,    1.5000,    0.2000], 
[    6.9000,    3.1000,    4.9000,    1.5000], 
[    5.2000,    2.7000,    3.9000,    1.4000], 
[    5.1000,    2.5000,    3.0000,    1.1000], 
[    4.9000,    3.1000,    1.5000,    0.1000], 
[    7.7000,    2.6000,    6.9000,    2.3000], 
[    5.1000,    3.7000,    1.5000,    0.4000], 
[    4.6000,    3.4000,    1.4000,    0.3000], 
[    6.7000,    3.0000,    5.0000,    1.7000], 
[    5.0000,    2.0000,    3.5000,    1.0000], 
[    7.6000,    3.0000,    6.6000,    2.1000], 
[    5.4000,    3.9000,    1.3000,    0.4000], 
[    5.0000,    3.3000,    1.4000,    0.2000]]
=================OUTPUT==================
[[    1.0000,         0,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[         0,         0,    1.0000], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[         0,    1.0000,         0], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[         0,    1.0000,         0], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[    1.0000,         0,         0], 
[         0,    1.0000,         0], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[         0,    1.0000,         0], 
[         0,         0,    1.0000], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[         0,         0,    1.0000], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[         0,    1.0000,         0], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[    1.0000,         0,         0], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[         0,         0,    1.0000], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[         0,    1.0000,         0], 
[         0,         0,    1.0000], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[         0,    1.0000,         0], 
[         0,         0,    1.0000], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[         0,         0,    1.0000], 
[    1.0000,         0,         0], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[    1.0000,         0,         0], 
[         0,         0,    1.0000], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0], 
[         0,    1.0000,         0], 
[         0,    1.0000,         0], 
[         0,         0,    1.0000], 
[    1.0000,         0,         0], 
[    1.0000,         0,         0]]
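The OUTPUT block shows that RecordReaderDataSetIterator has turned the integer label in column 4 into a one-hot vector of length numClasses = 3: Setosa (0) becomes [1, 0, 0], Versicolor (1) becomes [0, 1, 0] and Virginica (2) becomes [0, 0, 1]. A plain-Scala sketch of this encoding and its inverse (hypothetical helpers, not DL4J APIs):

```scala
object OneHot {
  // Encode a class index as a vector with a single 1.0 at that index
  def oneHot(label: Int, numClasses: Int): Vector[Double] = {
    require(label >= 0 && label < numClasses, s"label $label out of range")
    Vector.tabulate(numClasses)(i => if (i == label) 1.0 else 0.0)
  }

  // The inverse: the class is the index of the largest value (argmax)
  def argMax(v: Seq[Double]): Int = v.indices.maxBy(i => v(i))
}
```

The same argmax is what turns the network's softmax output back into a predicted class during evaluation.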

We need to normalize our data. For this we can use NormalizerStandardize (which transforms each feature to mean 0 and unit variance): we fit it on the training set and then apply it to both the training and the test set.
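The fit/transform split can be sketched in plain Scala: fit collects the per-feature mean and standard deviation from the training rows only, and transform applies (x - mean) / std to any rows, so the test set is scaled with the training statistics. This is a hypothetical helper, not NormalizerStandardize itself; dividing by n (population standard deviation) and replacing a zero std by 1.0 are assumptions that may differ from ND4J's implementation.

```scala
object Standardize {
  // Per-feature statistics collected from the training data only
  final case class Stats(mean: Vector[Double], std: Vector[Double])

  def fit(rows: Seq[Vector[Double]]): Stats = {
    require(rows.nonEmpty, "need at least one row")
    val n = rows.size.toDouble
    val numFeatures = rows.head.size
    val mean = Vector.tabulate(numFeatures)(j => rows.map(r => r(j)).sum / n)
    // Population standard deviation (divide by n); this choice is an assumption
    val std = Vector.tabulate(numFeatures) { j =>
      math.sqrt(rows.map(r => math.pow(r(j) - mean(j), 2)).sum / n)
    }
    Stats(mean, std)
  }

  // Apply (x - mean) / std using previously fitted statistics; a zero std is treated as 1.0
  def transform(stats: Stats, rows: Seq[Vector[Double]]): Seq[Vector[Double]] =
    rows.map { r =>
      Vector.tabulate(r.size) { j =>
        val s = if (stats.std(j) == 0.0) 1.0 else stats.std(j)
        (r(j) - stats.mean(j)) / s
      }
    }
}
```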

var normalizer = new NormalizerStandardize()
normalizer.fit(trainingData)            // Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.transform(trainingData)      // Apply normalization to the training data
normalizer.transform(testData)          // Apply normalization to the test data. This is using statistics calculated from the *training* set

Set Up the Neural Network Model

The syntax is similar to the one used by Keras. We can define the network in a declarative way:


var numInputs = 4
var outputNum = 3
var seed = 6
var conf = new NeuralNetConfiguration.Builder()
    .seed(seed)
    .activation(Activation.TANH)
    .weightInit(WeightInit.XAVIER)
    .updater(new Sgd(0.1))
    .l2(1e-4)
    .list()
    .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3).build())
    .layer(1, new DenseLayer.Builder().nIn(3).nOut(3).build())
    .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .activation(Activation.SOFTMAX)
        .nIn(3).nOut(outputNum).build())
    .backprop(true).pretrain(false)
    .build()
{
"backprop" : true,
"backpropType" : "Standard",
"cacheMode" : "NONE",
"confs" : [ {
"cacheMode" : "NONE",
"epochCount" : 0,
"iterationCount" : 0,
"layer" : {
"@class" : "org.deeplearning4j.nn.conf.layers.DenseLayer",
"activationFn" : {
"@class" : "org.nd4j.linalg.activations.impl.ActivationTanH"
},
"biasInit" : 0.0,
"biasUpdater" : null,
"constraints" : null,
"dist" : null,
"gradientNormalization" : "None",
"gradientNormalizationThreshold" : 1.0,
"hasBias" : true,
"idropout" : null,
"iupdater" : {
"@class" : "org.nd4j.linalg.learning.config.Sgd",
"learningRate" : 0.1
},
"l1" : 0.0,
"l1Bias" : 0.0,
"l2" : 1.0E-4,
"l2Bias" : 0.0,
"layerName" : "layer0",
"nin" : 4,
"nout" : 3,
"pretrain" : false,
"weightInit" : "XAVIER",
"weightNoise" : null
},
"maxNumLineSearchIterations" : 5,
"miniBatch" : true,
"minimize" : true,
"optimizationAlgo" : "STOCHASTIC_GRADIENT_DESCENT",
"pretrain" : false,
"seed" : 6,
"stepFunction" : null,
"variables" : [ ]
}, {
"cacheMode" : "NONE",
"epochCount" : 0,
"iterationCount" : 0,
"layer" : {
"@class" : "org.deeplearning4j.nn.conf.layers.DenseLayer",
"activationFn" : {
"@class" : "org.nd4j.linalg.activations.impl.ActivationTanH"
},
"biasInit" : 0.0,
"biasUpdater" : null,
"constraints" : null,
"dist" : null,
"gradientNormalization" : "None",
"gradientNormalizationThreshold" : 1.0,
"hasBias" : true,
"idropout" : null,
"iupdater" : {
"@class" : "org.nd4j.linalg.learning.config.Sgd",
"learningRate" : 0.1
},
"l1" : 0.0,
"l1Bias" : 0.0,
"l2" : 1.0E-4,
"l2Bias" : 0.0,
"layerName" : "layer1",
"nin" : 3,
"nout" : 3,
"pretrain" : false,
"weightInit" : "XAVIER",
"weightNoise" : null
},
"maxNumLineSearchIterations" : 5,
"miniBatch" : true,
"minimize" : true,
"optimizationAlgo" : "STOCHASTIC_GRADIENT_DESCENT",
"pretrain" : false,
"seed" : 6,
"stepFunction" : null,
"variables" : [ ]
}, {
"cacheMode" : "NONE",
"epochCount" : 0,
"iterationCount" : 0,
"layer" : {
"@class" : "org.deeplearning4j.nn.conf.layers.OutputLayer",
"activationFn" : {
"@class" : "org.nd4j.linalg.activations.impl.ActivationSoftmax"
},
"biasInit" : 0.0,
"biasUpdater" : null,
"constraints" : null,
"dist" : null,
"gradientNormalization" : "None",
"gradientNormalizationThreshold" : 1.0,
"hasBias" : true,
"idropout" : null,
"iupdater" : {
"@class" : "org.nd4j.linalg.learning.config.Sgd",
"learningRate" : 0.1
},
"l1" : 0.0,
"l1Bias" : 0.0,
"l2" : 1.0E-4,
"l2Bias" : 0.0,
"layerName" : "layer2",
"lossFn" : {
"@class" : "org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood",
"configProperties" : false,
"numOutputs" : -1,
"softmaxClipEps" : 1.0E-10
},
"nin" : 3,
"nout" : 3,
"pretrain" : false,
"weightInit" : "XAVIER",
"weightNoise" : null
},
"maxNumLineSearchIterations" : 5,
"miniBatch" : true,
"minimize" : true,
"optimizationAlgo" : "STOCHASTIC_GRADIENT_DESCENT",
"pretrain" : false,
"seed" : 6,
"stepFunction" : null,
"variables" : [ ]
} ],
"epochCount" : 0,
"inferenceWorkspaceMode" : "ENABLED",
"inputPreProcessors" : { },
"iterationCount" : 0,
"pretrain" : false,
"tbpttBackLength" : 20,
"tbpttFwdLength" : 20,
"trainingWorkspaceMode" : "ENABLED"
}
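The configuration above describes a small fully connected network: 4 inputs, two hidden layers with 3 tanh units each, and a 3-unit softmax output layer trained with negative log-likelihood. The following plain-Scala sketch shows the forward pass and the loss for a single example. The names are hypothetical, the weights are passed in explicitly instead of being initialized with Xavier, and L2 regularization is left out; it illustrates the computation, not DL4J's implementation.

```scala
object Forward {
  type Matrix = Vector[Vector[Double]] // rows = output units (nOut), columns = inputs (nIn)

  // Dense layer: act(W x + b)
  def dense(w: Matrix, b: Vector[Double], x: Vector[Double], act: Double => Double): Vector[Double] =
    w.indices.map(i => act(w(i).indices.map(j => w(i)(j) * x(j)).sum + b(i))).toVector

  // Numerically stable softmax: subtract the maximum before exponentiating
  def softmax(z: Vector[Double]): Vector[Double] = {
    val m = z.max
    val e = z.map(v => math.exp(v - m))
    val s = e.sum
    e.map(_ / s)
  }

  // Negative log-likelihood of the true class, given a one-hot label
  def nll(probs: Vector[Double], oneHot: Vector[Double]): Double = {
    val target = oneHot.indexOf(1.0)
    require(target >= 0, "label must be one-hot")
    -math.log(probs(target))
  }

  // Hidden layers use tanh, the last layer uses softmax (4 -> 3 -> 3 -> 3 in the notebook)
  def predict(layers: Seq[(Matrix, Vector[Double])], x: Vector[Double]): Vector[Double] = {
    val hidden = layers.init.foldLeft(x) { case (h, (w, b)) => dense(w, b, h, math.tanh) }
    val (wOut, bOut) = layers.last
    softmax(dense(wOut, bOut, hidden, v => v))
  }
}
```

The softmax output always sums to 1, which is why it can be read as class probabilities; the loss is small when the probability of the true class is close to 1.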

Run the Model

Now that the data is available and the model is defined, we can train the model for a predefined number of epochs.
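Because the whole training set is passed to fit as a single batch, each call to fit performs one parameter update with the Sgd updater, so here one iteration corresponds to one epoch and ScoreIterationListener(100) reports the score every 100 epochs. The plain-Scala sketch below illustrates this full-batch gradient descent on a simpler model, softmax regression without hidden layers, reusing the learning rate 0.1 and the L2 coefficient 1e-4 from the configuration. It is not DL4J's implementation: applying L2 as an extra l2 * w gradient term on the weights only is an assumption, and all names are hypothetical.

```scala
object SoftmaxRegression {
  // One row of feature weights per class, plus one bias per class
  final case class Model(w: Vector[Vector[Double]], b: Vector[Double])

  // Class probabilities: softmax(W x + b)
  def probs(m: Model, x: Vector[Double]): Vector[Double] = {
    val z = m.w.indices.map(k => m.w(k).indices.map(j => m.w(k)(j) * x(j)).sum + m.b(k)).toVector
    val mx = z.max
    val e = z.map(v => math.exp(v - mx))
    val s = e.sum
    e.map(_ / s)
  }

  // Mean negative log-likelihood over the batch (labels are class indices)
  def loss(m: Model, xs: Seq[Vector[Double]], ys: Seq[Int]): Double =
    xs.zip(ys).map { case (x, y) => -math.log(probs(m, x)(y)) }.sum / xs.size

  // One full-batch gradient-descent step; L2 is applied to the weights, not to the biases
  def step(m: Model, xs: Seq[Vector[Double]], ys: Seq[Int], lr: Double, l2: Double): Model = {
    val n = xs.size.toDouble
    val numClasses = m.b.size
    val numFeatures = m.w.head.size
    // For softmax + negative log-likelihood: dLoss/dz_k = p_k - 1{k == y}
    val deltas = xs.zip(ys).map { case (x, y) =>
      probs(m, x).zipWithIndex.map { case (p, k) => if (k == y) p - 1.0 else p }
    }
    val newW = Vector.tabulate(numClasses, numFeatures) { (k, j) =>
      val grad = xs.indices.map(i => deltas(i)(k) * xs(i)(j)).sum / n + l2 * m.w(k)(j)
      m.w(k)(j) - lr * grad
    }
    val newB = Vector.tabulate(numClasses)(k => m.b(k) - lr * deltas.map(d => d(k)).sum / n)
    Model(newW, newB)
  }

  // One step per epoch: with a single full batch, epochs == parameter updates
  def train(m0: Model, xs: Seq[Vector[Double]], ys: Seq[Int], epochs: Int): Model =
    (0 until epochs).foldLeft(m0)((m, _) => step(m, xs, ys, lr = 0.1, l2 = 1e-4))
}
```

Starting from all-zero weights, every class has probability 1/3, so the initial loss is ln 3; each epoch of gradient descent should then lower it.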

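// train the model
var model = new MultiLayerNetwork(conf)
model.init()
model.setListeners(new ScoreIterationListener(100))  // print the score every 100 iterations
var epochs = 1000
for (_ <- 0 until epochs) {  // 'until' excludes the upper bound, so this runs exactly 1000 epochs
    model.fit(trainingData)
}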

Evaluate the Model

Finally, we can evaluate the performance of the model on the test set:

The model predicts the correct variety for all 15 test examples (an accuracy of 100%).

//evaluate the model on the test set
var eval = new Evaluation(3)
var output = model.output(testData.getFeatures())
eval.eval(testData.getLabels(), output)
eval.stats()
========================Evaluation Metrics========================
# of classes:    3
Accuracy:        1.0000
Precision:       1.0000
Recall:          1.0000
F1 Score:        1.0000
Precision, recall & F1: macro-averaged (equally weighted avg. of 3 classes)
=========================Confusion Matrix=========================
0 1 2
-------
3 0 0 | 0 = 0
0 5 0 | 1 = 1
0 0 7 | 2 = 2
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
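Evaluation compares the predicted class, the index of the largest softmax output, with the actual class. A plain-Scala sketch of the accuracy and confusion-matrix computation (hypothetical helpers, not the Evaluation class), with rows as actual and columns as predicted classes, matching the format note printed above:

```scala
object Eval {
  // Predicted class = index of the largest output value
  def argMax(v: Seq[Double]): Int = v.indices.maxBy(i => v(i))

  // counts(actual)(predicted): how often class `actual` was predicted as class `predicted`
  def confusion(actual: Seq[Int], predicted: Seq[Int], numClasses: Int): Vector[Vector[Int]] = {
    require(actual.size == predicted.size, "actual and predicted must have the same length")
    val pairs = actual.zip(predicted)
    Vector.tabulate(numClasses, numClasses) { (a, p) =>
      pairs.count { case (x, y) => x == a && y == p }
    }
  }

  // Accuracy = correctly classified (the diagonal) / total number of examples
  def accuracy(cm: Vector[Vector[Int]]): Double = {
    val total = cm.map(_.sum).sum
    val correct = cm.indices.map(i => cm(i)(i)).sum
    correct.toDouble / total
  }
}
```

With 3 Setosa, 5 Versicolor and 7 Virginica test examples all predicted correctly, every count lies on the diagonal and the accuracy is 15 / 15 = 1.0.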
