Getting Started with Deep Learning in Java Using Deep Netts
- July 05, 2022
Deep Netts is a pure Java deep learning library with a friendly, Java-centric API.
It makes it easy for Java developers to get started with deep learning quickly, and it integrates easily with existing Java applications.
It supports commonly used neural network architectures (feed-forward networks and convolutional networks) for classification, regression and image recognition tasks.
Adding Deep Netts to your Project
To use Deep Netts in a Maven-based Java project, add the following dependency to the dependencies section of your pom.xml file:
<dependency>
    <groupId>com.deepnetts</groupId>
    <artifactId>deepnetts-core</artifactId>
    <version>1.13.2</version>
</dependency>
You can also clone the entire library and its examples from GitHub: https://github.com/deepnetts/deepnetts-communityedition
Hello World: Iris Flowers Classification
The Iris flowers classification problem is commonly used as a "hello world" example for machine learning.
Briefly, we have a CSV file that contains four attributes describing Iris flowers (sepal length, sepal width, petal length and petal width) and three categories (species) of Iris flowers: setosa, versicolor and virginica.
For more details, see https://en.wikipedia.org/wiki/Iris_flower_data_set.
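In the example below, the call `DataSets.readCsv("datasets/iris.csv", numInputs, numOutputs, true)` expects 4 inputs and 3 outputs per row. A plausible reading of that call is that each CSV row holds the four measurements followed by a one-hot encoding of the species (for example `1,0,0` for setosa). This layout and the species order are assumptions and have not been confirmed against the actual data file. The plain-Java sketch below does not use Deep Netts. Its hypothetical class `IrisRowSketch` splits such a row into an input vector and a target vector.

```java
import java.util.Arrays;

// Hypothetical helper illustrating an ASSUMED row layout:
// 4 feature columns followed by 3 one-hot class columns.
class IrisRowSketch {

    static final int NUM_INPUTS = 4;  // sepal length, sepal width, petal length, petal width
    static final int NUM_OUTPUTS = 3; // assumed order: setosa, versicolor, virginica

    /** Returns the input features (the first NUM_INPUTS columns) of one CSV row. */
    static double[] inputs(String row) {
        return Arrays.copyOfRange(parse(row), 0, NUM_INPUTS);
    }

    /** Returns the one-hot target vector stored after the input columns. */
    static double[] targets(String row) {
        return Arrays.copyOfRange(parse(row), NUM_INPUTS, NUM_INPUTS + NUM_OUTPUTS);
    }

    /** Parses a comma-separated row into numbers, checking the expected column count. */
    private static double[] parse(String row) {
        String[] cells = row.split(",");
        if (cells.length != NUM_INPUTS + NUM_OUTPUTS) {
            throw new IllegalArgumentException(
                    "Expected " + (NUM_INPUTS + NUM_OUTPUTS) + " columns, got " + cells.length);
        }
        double[] values = new double[cells.length];
        for (int i = 0; i < cells.length; i++) {
            values[i] = Double.parseDouble(cells[i].trim());
        }
        return values;
    }

    public static void main(String[] args) {
        // first sample of Fisher's Iris data (a setosa), written in the assumed one-hot layout
        String row = "5.1,3.5,1.4,0.2,1,0,0";
        System.out.println(Arrays.toString(inputs(row)));  // [5.1, 3.5, 1.4, 0.2]
        System.out.println(Arrays.toString(targets(row))); // [1.0, 0.0, 0.0]
    }
}
```

Read this way, `numInputs` and `numOutputs` simply tell the reader where the feature columns end and the class columns begin.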
package deepnetts.examples;

import deepnetts.data.DataSets;
import deepnetts.data.preprocessing.scale.MaxScaler;
import deepnetts.eval.ClassifierEvaluator;
import deepnetts.eval.ConfusionMatrix;
import deepnetts.net.FeedForwardNetwork;
import deepnetts.net.layers.activation.ActivationType;
import deepnetts.net.loss.LossType;
import deepnetts.net.train.BackpropagationTrainer;
import deepnetts.net.train.opt.OptimizerType;
import deepnetts.util.DeepNettsException;
import java.io.IOException;
import javax.visrec.ml.data.DataSet;
import javax.visrec.ml.data.preprocessing.Scaler;
import javax.visrec.ml.eval.EvaluationMetrics;

public class IrisFlowersClassifier {

    public static void main(String[] args) throws DeepNettsException, IOException {
        int numInputs = 4;  // number of input features/attributes in the data set
        int numOutputs = 3; // number of categories/classes in the data set

        // load the Iris data set from a CSV file
        DataSet dataSet = DataSets.readCsv("datasets/iris.csv", numInputs, numOutputs, true);

        // scale data to the range [0,1] to make it suitable for neural network processing
        Scaler scaler = new MaxScaler(dataSet);
        scaler.apply(dataSet);

        // split the loaded data into training and test sets in a 60:40 ratio
        DataSet[] trainTestSet = dataSet.split(0.6, 0.4);
        DataSet trainingSet = trainTestSet[0]; // part of the data used for training
        DataSet testSet = trainTestSet[1];     // part of the data used for testing/evaluation

        // create a feed-forward neural network (aka multilayer perceptron) using its builder
        FeedForwardNetwork neuralNet = FeedForwardNetwork.builder()
                .addInputLayer(numInputs)                           // accepts inputs from the data set; its size must match the number of inputs
                .addFullyConnectedLayer(8, ActivationType.RELU)     // hidden fully connected layer enables solving more complex problems
                .addOutputLayer(numOutputs, ActivationType.SOFTMAX) // activation commonly used in the output layer for multi-class classification
                .lossFunction(LossType.CROSS_ENTROPY)               // loss function commonly used for multi-class classification problems
                .randomSeed(456)                                    // fix the randomization seed so results can be repeated; any value works
                .build();

        // get and configure the training algorithm for the network: the backpropagation trainer
        BackpropagationTrainer trainer = neuralNet.getTrainer();
        trainer.setMaxError(0.04f);                   // training stops when this error value is reached
        trainer.setLearningRate(0.01f);               // learning step: percentage of the error used to tune internal weights, in [0, 0.9]
        trainer.setOptimizer(OptimizerType.MOMENTUM); // use an accelerated optimization method
        trainer.setMomentum(0.9f);                    // amount of acceleration to use

        // run the training
        neuralNet.train(trainingSet);

        // evaluate/test the classifier to estimate how it will behave on unseen data
        ClassifierEvaluator evaluator = new ClassifierEvaluator();
        EvaluationMetrics em = evaluator.evaluate(neuralNet, testSet);
        System.out.println("CLASSIFIER EVALUATION METRICS");
        System.out.println(em); // print classifier test results

        // print details of the confusion matrix
        System.out.println("CONFUSION MATRIX");
        ConfusionMatrix cm = evaluator.getConfusionMatrix();
        System.out.println(cm);
    }
}
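The comment on `MaxScaler` says it scales the data to the range [0,1]. One simple way to do that is to divide every value in a column by that column's maximum, as the class name suggests. The following is a plain-Java sketch of that idea under this assumption, for non-negative data. Deep Netts' actual implementation may differ, for example in whether it also scales the output columns or how it treats negative values. `MaxScalingSketch` and `scaleByColumnMax` are hypothetical names.

```java
import java.util.Arrays;

// Hypothetical sketch (not Deep Netts' MaxScaler source): divide every value
// in a column by that column's maximum, assuming non-negative data.
class MaxScalingSketch {

    /** Scales each column of data in place by the column's maximum value. */
    static void scaleByColumnMax(double[][] data) {
        int numColumns = data[0].length;
        for (int c = 0; c < numColumns; c++) {
            double max = Double.NEGATIVE_INFINITY;
            for (double[] row : data) {
                max = Math.max(max, row[c]);
            }
            if (max == 0.0) {
                // with non-negative data this means an all-zero column;
                // leave it unchanged to avoid division by zero
                continue;
            }
            for (double[] row : data) {
                row[c] /= max;
            }
        }
    }

    public static void main(String[] args) {
        // toy feature matrix: 2 rows, 2 columns
        double[][] data = {
            {5.0, 2.0},
            {2.5, 4.0}
        };
        scaleByColumnMax(data);
        System.out.println(Arrays.deepToString(data)); // [[1.0, 0.5], [0.5, 1.0]]
    }
}
```

For non-negative features such as the Iris measurements, this maps every column into [0,1], and each column's largest value becomes exactly 1.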
The full source code of the example is available on GitHub.
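The network in the example pairs a SOFTMAX output layer with the CROSS_ENTROPY loss. The sketch below illustrates what these two settings compute, using the textbook definitions of softmax and cross-entropy for a single sample in plain Java. The class name is hypothetical, and Deep Netts may differ in details such as how it averages the loss over many samples.

```java
import java.util.Arrays;

// Hypothetical sketch of textbook softmax and cross-entropy (not Deep Netts source).
class SoftmaxCrossEntropySketch {

    /** Numerically stable softmax: subtracting the largest logit keeps Math.exp from overflowing. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double[] probs = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum; // normalize so the probabilities add up to 1
        }
        return probs;
    }

    /** Cross-entropy -sum(t_i * ln(p_i)) between a one-hot target and predicted probabilities. */
    static double crossEntropy(double[] target, double[] predicted) {
        double loss = 0.0;
        for (int i = 0; i < target.length; i++) {
            if (target[i] > 0) { // skip zero targets, so 0 * ln(0) never produces NaN
                loss -= target[i] * Math.log(predicted[i]);
            }
        }
        return loss;
    }

    public static void main(String[] args) {
        double[] logits = {2.0, 1.0, 0.1}; // toy raw outputs for setosa, versicolor, virginica
        double[] probs = softmax(logits);
        double[] target = {1, 0, 0};       // the sample is actually a setosa
        System.out.println(Arrays.toString(probs));
        System.out.println(crossEntropy(target, probs));
    }
}
```

For these toy logits, the probabilities come out at roughly 0.66, 0.24 and 0.10, and the loss is about 0.42. The loss approaches 0 as the predicted probability of the true class approaches 1.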
Running it produces output similar to this:
-------------------------------------------------------------------------------------
TRAINING NEURAL NETWORK
-------------------------------------------------------------------------------------
Epoch:1, Time:4ms, TrainError:0.947609, TrainErrorChange:0.947609, TrainAccuracy: 0.62857145
Epoch:2, Time:2ms, TrainError:0.58163124, TrainErrorChange:-0.36597776, TrainAccuracy: 0.6393557
...
Epoch:280, Time:0ms, TrainError:0.04000282, TrainErrorChange:-5.1554292E-5, TrainAccuracy: 0.97840476
Epoch:281, Time:0ms, TrainError:0.03988598, TrainErrorChange:-1.16840005E-4, TrainAccuracy: 0.97840476
TRAINING COMPLETED
Total Training Time: 134ms
------------------------------------------------------------------------
CLASSIFIER EVALUATION METRICS
Accuracy: 0.93460923 (How often is classifier correct in total)
Precision: 0.96491224 (How often is classifier correct when it gives positive prediction)
F1Score: 0.96491224 (Harmonic average (balance) of precision and recall)
Recall: 0.96491224 (When it is actually positive class, how often does it give positive prediction)
CONFUSION MATRIX
                 none     setosa versicolor  virginica
      none          0          0          0          0
    setosa          0         21          0          0
versicolor          0          0         20          2
 virginica          0          0          0         17
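The evaluation metrics can be related to the confusion matrix above. The sketch below computes accuracy and macro-averaged precision and recall (the unweighted mean over classes) using their textbook definitions. It drops the all-zero `none` row and column, and it assumes that rows are actual classes and columns are predicted classes. The class and method names are hypothetical and are not part of the Deep Netts API.

```java
// Hypothetical helper (not Deep Netts API): textbook metrics from a confusion matrix,
// ASSUMING cm[actual][predicted] holds the number of samples.
class ConfusionMatrixMetricsSketch {

    /** Fraction of all samples on the diagonal, i.e. classified correctly. */
    static double accuracy(int[][] cm) {
        int correct = 0;
        int total = 0;
        for (int i = 0; i < cm.length; i++) {
            for (int j = 0; j < cm[i].length; j++) {
                total += cm[i][j];
                if (i == j) {
                    correct += cm[i][j];
                }
            }
        }
        return (double) correct / total;
    }

    /** Precision of class k: correct predictions of k divided by all predictions of k (column sum). */
    static double precision(int[][] cm, int k) {
        int predicted = 0;
        for (int[] row : cm) {
            predicted += row[k];
        }
        return predicted == 0 ? 0.0 : (double) cm[k][k] / predicted;
    }

    /** Recall of class k: correct predictions of k divided by all actual samples of k (row sum). */
    static double recall(int[][] cm, int k) {
        int actual = 0;
        for (int count : cm[k]) {
            actual += count;
        }
        return actual == 0 ? 0.0 : (double) cm[k][k] / actual;
    }

    /** Unweighted mean of the per-class precision values. */
    static double macroPrecision(int[][] cm) {
        double sum = 0.0;
        for (int k = 0; k < cm.length; k++) {
            sum += precision(cm, k);
        }
        return sum / cm.length;
    }

    /** Unweighted mean of the per-class recall values. */
    static double macroRecall(int[][] cm) {
        double sum = 0.0;
        for (int k = 0; k < cm.length; k++) {
            sum += recall(cm, k);
        }
        return sum / cm.length;
    }

    public static void main(String[] args) {
        // setosa, versicolor, virginica counts taken from the output above ("none" dropped)
        int[][] cm = {
            {21, 0, 0},
            {0, 20, 2},
            {0, 0, 17}
        };
        System.out.println("Accuracy:  " + accuracy(cm));
        System.out.println("Precision: " + macroPrecision(cm));
        System.out.println("Recall:    " + macroRecall(cm));
    }
}
```

With these definitions, the macro-averaged precision is 55/57 ≈ 0.965, which matches the printed precision. Accuracy (58/60 ≈ 0.967) and recall (32/33 ≈ 0.970) do not match the printed figures, so Deep Netts presumably defines or averages those metrics differently.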
More examples like this that you can use as starter templates for your own AI/machine learning projects in Java are available at:
https://github.com/deepnetts/examples
Tip: Paul King has created a very cool example of using Deep Netts with Apache Groovy to get a Python-like development experience. It is available at https://github.com/paulk-asert/groovy-data-science/tree/master/subprojects/IrisGraalVM.
Links
- Deep Netts Community Edition at GitHub
- Official Deep Netts homepage
- Classifying Iris Flowers with Deep Learning, Groovy and GraalVM: a great comparison of various Java DL libraries