Deep Learning

This notebook serves as the supporting material for the chapter Deep Learning. In this notebook, we'll learn about different activation functions. Then we'll create a deep neural network using Deeplearning4j and train a model capable of classifying handwritten digits.

"While handwriting recognition has been attempted by different machine learning algorithms over the years, deep learning performs remarkably well and achieves an accuracy of over 99.7% on the MNIST dataset."

So, let's begin...


In [1]:
%%classpath add mvn
org.nd4j nd4j-native-platform 0.9.1
org.deeplearning4j deeplearning4j-core 0.9.1
org.datavec datavec-api 0.9.1
org.datavec datavec-local 0.9.1
org.datavec datavec-dataframe 0.9.1
org.bytedeco javacpp 1.5
org.apache.httpcomponents httpclient 4.3.5
org.deeplearning4j deeplearning4j-ui_2.11 0.9.1
com.xlson.groovycsv groovycsv 1.3


Activation Functions

1.) Saturating Activation Functions
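
A classic example of a saturating activation function is the sigmoid (logistic) function:

$\sigma(x) = \frac{1}{1+e^{-x}}$

Its output is squashed into the range $(0, 1)$: it approaches 1 for large positive inputs and 0 for large negative inputs, so its derivative approaches 0 at both extremes. This saturation is what causes gradients to vanish in deep sigmoid networks. The cell below plots the sigmoid and marks its saturating and roughly linear regions.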


In [7]:
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.api.iter.NdIndexIterator;

import java.util.*;

INDArray array = Nd4j.linspace(-5,5,200);
INDArray sigmoid = Transforms.sigmoid(array);

def ch = new Crosshair(color: Color.gray, width: 2, style: StrokeType.DOT);
p1 = new Plot(title: "Sigmoid activation function", crosshair: ch);
p1 << new ConstantLine(x: 0, y: 0, color: Color.black);
p1 << new ConstantLine(y: 1, color: Color.black, style: StrokeType.DOT);
p1 << new Line(x: [-5, 5], y: [-3/4, 7/4], style: StrokeType.DASH, color: Color.green);
p1 << new Line(x: toDoubleArrayList(array), y: toDoubleArrayList(sigmoid), displayName: "Sigmoid", color: Color.blue, width: 3);
p1 << new Text(x: 0, y: 0.5, text: "Linear", pointerAngle: 3.505);
p1 << new Text(x: -5, y: 0, text: "Saturating", pointerAngle: 1.57);
p1 << new Text(x: 5, y: 1, text: "Saturating", pointerAngle: 4.71);

public List<Double> toDoubleArrayList(INDArray array){
    NdIndexIterator iter = new NdIndexIterator(200);
    List<Double> list = new ArrayList<Double>();
    while (iter.hasNext()) {
        int[] nextIndex = iter.next();
        double nextVal = array.getDouble(nextIndex);
        list.add(nextVal);
    }
    return list;
}

OutputCell.HIDDEN

2.) Nonsaturating Activation Functions
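
The nonsaturating activation functions plotted in the next cell are the ReLU family. For reference, their usual definitions are:

$ReLU(x) = \max(0, x)$

$LeakyReLU(x) = \max(\alpha x, x)$, where $\alpha$ is a small slope (e.g. 0.01) applied for $x < 0$

$ELU(x) = x$ for $x \ge 0$, and $\alpha(e^{x} - 1)$ for $x < 0$

Unlike the sigmoid, these functions do not saturate for large positive inputs, which helps gradients keep flowing through deep networks.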


In [9]:
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.api.iter.NdIndexIterator;

import java.util.*;

INDArray array = Nd4j.linspace(-5,5,200);
INDArray relu = Transforms.relu(array);
INDArray leakyRelu = Transforms.leakyRelu(array);
INDArray elu = Transforms.elu(array);

def ch = new Crosshair(color: Color.gray, width: 2, style: StrokeType.DOT);
p1 = new Plot(title: "Non saturating activation function", crosshair: ch);
p1 << new ConstantLine(x: 0, y: 0, color: Color.black);
p1 << new ConstantLine(y: -1, color: Color.black, style: StrokeType.DOT);
p1.getYAxes()[0].setBound(-1.5,5);
p1 << new Line(x: toDoubleArrayList(array), y: toDoubleArrayList(elu), displayName: "ELU (α=1)", color: Color.red)
p1 << new Line(x: toDoubleArrayList(array), y: toDoubleArrayList(relu), displayName: "ReLU", color: Color.orange)
p1 << new Line(x: toDoubleArrayList(array), y: toDoubleArrayList(leakyRelu), displayName: "Leaky ReLU", color: Color.blue);
p1 << new Text(x: -5, y: 0, text: "Leak", pointerAngle: 1.57);


public List<Double> toDoubleArrayList(INDArray array){
    NdIndexIterator iter = new NdIndexIterator(200);
    List<Double> list = new ArrayList<Double>();
    while (iter.hasNext()) {
        int[] nextIndex = iter.next();
        double nextVal = array.getDouble(nextIndex);
        list.add(nextVal);
    }
    return list;
}

OutputCell.HIDDEN

Let's train a neural network on MNIST using the ReLU activation.

First, we need to create a DataUtils class containing the methods required for downloading, extracting, and deleting the dataset files.


In [2]:
package aima.notebooks.deeplearning;

import org.apache.commons.compress.archivers.tar.*;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.*;
import org.apache.http.impl.client.*;

import java.io.*;
import java.nio.file.*;
import java.nio.file.attribute.BasicFileAttributes;

public class DataUtils{
    
    public DataUtils(){}
    
    public boolean downloadFile(String remoteUrl, String localPath) throws IOException {
        boolean downloaded = false;
        if (remoteUrl == null || localPath == null)
            return downloaded;
        File file = new File(localPath);
        if (!file.exists()) {
            file.getParentFile().mkdirs();
            HttpClientBuilder builder = HttpClientBuilder.create();
            CloseableHttpClient client = builder.build();
            try {
                CloseableHttpResponse response = client.execute(new HttpGet(remoteUrl));
                HttpEntity entity = response.getEntity();
                if (entity != null) {
                    try {
                        FileOutputStream outstream = new FileOutputStream(file);
                        entity.writeTo(outstream);
                        outstream.flush();
                        outstream.close();
                    } catch(IOException e){
                        System.out.println(e);
                    }
                }
            } catch(IOException e){
                System.out.println(e);
            }
            downloaded = true;
        }
        if (!file.exists())
            throw new IOException("File doesn't exist: " + localPath);
        return downloaded;
    }
    public void extractTarGz(String inputPath, String outputPath) throws IOException {
        if (inputPath == null || outputPath == null)
            return;
        final int bufferSize = 4096;
        if (!outputPath.endsWith("" + File.separatorChar))
            outputPath = outputPath + File.separatorChar;
        try {
            TarArchiveInputStream tais = new TarArchiveInputStream(new GzipCompressorInputStream(new BufferedInputStream(new FileInputStream(inputPath))));
            TarArchiveEntry entry;
            while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) {
                if (entry.isDirectory()) {
                    new File(outputPath + entry.getName()).mkdirs();
                } else {
                    int count;
                    byte[] data = new byte[bufferSize];
                    FileOutputStream fos = new FileOutputStream(outputPath + entry.getName());
                    BufferedOutputStream dest = new BufferedOutputStream(fos, bufferSize);
                    while ((count = tais.read(data, 0, bufferSize)) != -1) {
                        dest.write(data, 0, count);
                    }
                    dest.close();
                }
            }
        } catch(IOException e){
            System.out.println(e);
        }
    }
    public void deleteDir(String path) throws IOException{
        Path directory = Paths.get(path);
        Files.walkFileTree(directory, new SimpleFileVisitor<Path>() {
            @Override
            public FileVisitResult visitFile(Path file, BasicFileAttributes attributes) throws IOException {
                Files.delete(file); // this will work because it's always a File
                return FileVisitResult.CONTINUE;
            }

            @Override
            public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOException {
                Files.delete(dir); //this will work because Files in the directory are already deleted
                return FileVisitResult.CONTINUE;
            }
        });
    }
}


Out[2]:
null

Now let's download the MNIST dataset.


In [3]:
import aima.notebooks.deeplearning.DataUtils;
import java.io.File;

String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";
String BASE_PATH = "./assets";
String localFilePath = BASE_PATH + "/mnist_png.tar.gz";
DataUtils dataUtils = new DataUtils();
if (!new File(localFilePath).exists()) {
    if (dataUtils.downloadFile(DATA_URL, localFilePath)) {
        dataUtils.extractTarGz(localFilePath, BASE_PATH);
    }
}


Out[3]:
null

In [5]:
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.optimize.listeners.*;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.ui.stats.StatsListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.*;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.util.Random;

int seed = 123;
double learningRate = 0.01;
int batchSize = 100;
int numEpochs = 1;

int height = 28;
int width = 28;
int channels = 1;
int numInput = height * width;
int numHidden = 1000;
int numOutput = 10;

//Prepare data for loading
File trainData = new File("./assets/mnist_png/training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, new Random(seed));
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // use parent directory name as the image label
ImageRecordReader trainRR = new ImageRecordReader(height, width, 1, labelMaker);
trainRR.initialize(trainSplit);
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, numOutput);
DataNormalization imageScaler = new ImagePreProcessingScaler();
imageScaler.fit(trainIter);
trainIter.setPreProcessor(imageScaler);


File testData = new File("./assets/mnist_png/testing");
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, new Random(seed));
ImageRecordReader testRR = new ImageRecordReader(height, width, 1, labelMaker);
testRR.initialize(testSplit);
DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, numOutput);
testIter.setPreProcessor(imageScaler);

//Build the neural network
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(Updater.ADAM)
        .list()
        .layer(0, new DenseLayer.Builder()
                .nIn(numInput)
                .nOut(numHidden)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .build())
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nIn(numHidden)
                .nOut(numOutput)
                .activation(Activation.SOFTMAX)
                .weightInit(WeightInit.XAVIER)
                .build())
        .setInputType(InputType.convolutional(height, width, channels))
        .build();


MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
CollectScoresIterationListener iterationListener = new CollectScoresIterationListener();
model.setListeners(iterationListener);


//Train the model and evaluate
for (int i = 0; i < numEpochs; i++) {
    model.fit(trainIter);
    System.out.println("********Evaluation Stats*********");
    Evaluation eval = model.evaluate(testIter);
    System.out.println(eval.stats());

    trainIter.reset();
    testIter.reset();
}

iterationListener.exportScores(new File("assets/training_scores/dense.csv"));

System.out.println("********Example finished*********");


********Evaluation Stats*********

Examples labeled as 0 classified by model as 0: 960 times
Examples labeled as 0 classified by model as 2: 3 times
Examples labeled as 0 classified by model as 3: 2 times
Examples labeled as 0 classified by model as 4: 2 times
Examples labeled as 0 classified by model as 5: 1 times
Examples labeled as 0 classified by model as 6: 2 times
Examples labeled as 0 classified by model as 7: 5 times
Examples labeled as 0 classified by model as 8: 2 times
Examples labeled as 0 classified by model as 9: 3 times
Examples labeled as 1 classified by model as 1: 1124 times
Examples labeled as 1 classified by model as 2: 2 times
Examples labeled as 1 classified by model as 3: 1 times
Examples labeled as 1 classified by model as 4: 1 times
Examples labeled as 1 classified by model as 5: 1 times
Examples labeled as 1 classified by model as 6: 3 times
Examples labeled as 1 classified by model as 8: 3 times
Examples labeled as 2 classified by model as 0: 3 times
Examples labeled as 2 classified by model as 1: 2 times
Examples labeled as 2 classified by model as 2: 1009 times
Examples labeled as 2 classified by model as 3: 3 times
Examples labeled as 2 classified by model as 4: 2 times
Examples labeled as 2 classified by model as 6: 2 times
Examples labeled as 2 classified by model as 7: 5 times
Examples labeled as 2 classified by model as 8: 5 times
Examples labeled as 2 classified by model as 9: 1 times
Examples labeled as 3 classified by model as 2: 10 times
Examples labeled as 3 classified by model as 3: 978 times
Examples labeled as 3 classified by model as 5: 7 times
Examples labeled as 3 classified by model as 7: 7 times
Examples labeled as 3 classified by model as 8: 4 times
Examples labeled as 3 classified by model as 9: 4 times
Examples labeled as 4 classified by model as 2: 6 times
Examples labeled as 4 classified by model as 4: 971 times
Examples labeled as 4 classified by model as 6: 1 times
Examples labeled as 4 classified by model as 8: 2 times
Examples labeled as 4 classified by model as 9: 2 times
Examples labeled as 5 classified by model as 0: 4 times
Examples labeled as 5 classified by model as 2: 1 times
Examples labeled as 5 classified by model as 3: 14 times
Examples labeled as 5 classified by model as 4: 4 times
Examples labeled as 5 classified by model as 5: 855 times
Examples labeled as 5 classified by model as 6: 6 times
Examples labeled as 5 classified by model as 7: 1 times
Examples labeled as 5 classified by model as 8: 5 times
Examples labeled as 5 classified by model as 9: 2 times
Examples labeled as 6 classified by model as 0: 5 times
Examples labeled as 6 classified by model as 1: 2 times
Examples labeled as 6 classified by model as 2: 4 times
Examples labeled as 6 classified by model as 3: 1 times
Examples labeled as 6 classified by model as 4: 6 times
Examples labeled as 6 classified by model as 5: 5 times
Examples labeled as 6 classified by model as 6: 931 times
Examples labeled as 6 classified by model as 7: 1 times
Examples labeled as 6 classified by model as 8: 3 times
Examples labeled as 7 classified by model as 1: 10 times
Examples labeled as 7 classified by model as 2: 15 times
Examples labeled as 7 classified by model as 3: 4 times
Examples labeled as 7 classified by model as 4: 4 times
Examples labeled as 7 classified by model as 5: 1 times
Examples labeled as 7 classified by model as 7: 972 times
Examples labeled as 7 classified by model as 8: 1 times
Examples labeled as 7 classified by model as 9: 21 times
Examples labeled as 8 classified by model as 0: 3 times
Examples labeled as 8 classified by model as 2: 6 times
Examples labeled as 8 classified by model as 3: 11 times
Examples labeled as 8 classified by model as 4: 8 times
Examples labeled as 8 classified by model as 5: 4 times
Examples labeled as 8 classified by model as 6: 4 times
Examples labeled as 8 classified by model as 7: 7 times
Examples labeled as 8 classified by model as 8: 923 times
Examples labeled as 8 classified by model as 9: 8 times
Examples labeled as 9 classified by model as 0: 3 times
Examples labeled as 9 classified by model as 1: 7 times
Examples labeled as 9 classified by model as 2: 1 times
Examples labeled as 9 classified by model as 3: 9 times
Examples labeled as 9 classified by model as 4: 32 times
Examples labeled as 9 classified by model as 5: 3 times
Examples labeled as 9 classified by model as 7: 4 times
Examples labeled as 9 classified by model as 8: 1 times
Examples labeled as 9 classified by model as 9: 949 times


==========================Scores========================================
 # of classes:    10
 Accuracy:        0.9672
 Precision:       0.9674
 Recall:          0.9669
 F1 Score:        0.9670
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
========================================================================
********Example finished*********
Out[5]:
null

In [5]:
import static com.xlson.groovycsv.CsvParser.parseCsv

def file = new File("assets/training_scores/dense.csv");
def csv_content = file.getText('utf-8')
def data_iterator = parseCsv(csv_content, separator: '	', readFirstLine: false)
def iter = [];
def scores = [];
def i = 1;
for(line in data_iterator){iter.add(i); scores.add(line[1]);++i}

def plot = new Plot(title: "Training Curve", yLabel: "Model score", xLabel: "Iteration");
plot << new Line(x: iter, y: scores)
OutputCell.HIDDEN

Optimisation Algorithms

Training a neural network consists of modifying the network's parameters so as to minimize the cost function on the training set. These internal parameters determine what the model learns and how accurately it predicts, so we use optimization algorithms to compute good values for them efficiently and effectively. In principle, any kind of optimization algorithm can be used; in practice, we generally use first-order optimization algorithms.

  • First-Order Optimization Algorithms: These algorithms minimize the loss function using its gradient with respect to the parameters. The most widely used first-order optimization algorithm is gradient descent. The gradient is almost always calculated using an algorithm called back-propagation.

The gradient points in the direction of steepest increase of the loss function. Since we want to find the minimum point in the valley, we need to move in the opposite direction, so we update the parameters along the negative gradient to minimize the loss.

Now let's discuss the various algorithms (also called updaters or optimizers) that are used to improve on plain gradient descent.
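
In Deeplearning4j the updater is selected on the network configuration builder, exactly as in the training cells above (.updater(Updater.ADAM)). Below is a minimal sketch of how a different updater could be chosen with the 0.9.1 API used in this notebook; the enum names other than ADAM are assumptions based on that Updater enum:

import org.deeplearning4j.nn.conf.*;

// Pick the optimizer by swapping the Updater enum value. Besides ADAM (used above),
// values such as SGD, NESTEROVS, ADAGRAD, RMSPROP and ADADELTA are assumed to exist
// in the 0.9.1 Updater enum.
Updater chosenUpdater = Updater.RMSPROP;

NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
        .seed(123)
        .updater(chosenUpdater);   // same pattern as the training cells above

The update rules these updaters implement are described below.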

1.) Stochastic Gradient Descent (SGD)

Consider a network with parameters $\theta$ and cost function $L(\theta)$. SGD updates the model parameters $\theta$ in the direction of the negative gradient $g$, estimated on a mini-batch of $m$ training examples.

$g = \frac{1}{m}\nabla_{\theta}\sum_i L(f(x^{(i)},\theta),y^{(i)})$

$\theta = \theta - (\epsilon_{k}*g)$

where the neural network is represented by $f(x^{(i)},\theta)$, $x^{(i)}$ are the training inputs, $y^{(i)}$ are the training labels, and the gradient of the loss $L$ is computed with respect to the model parameters $\theta$. The learning rate $\epsilon_{k}$ determines the size of the step the algorithm takes along the gradient.
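
To make the update rule concrete, here is a minimal plain-Java sketch of a single SGD step (an illustration only, not Deeplearning4j's implementation; it assumes the mini-batch gradient has already been computed):

// One SGD update step.
// theta: parameters, grad: mini-batch gradient g, lr: learning rate epsilon_k.
void sgdStep(double[] theta, double[] grad, double lr) {
    for (int i = 0; i < theta.length; i++) {
        theta[i] -= lr * grad[i];   // theta = theta - epsilon_k * g
    }
}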

2.) Nesterov Momentum

Momentum is like a ball rolling downhill: the ball gains momentum as it rolls. Momentum helps accelerate gradient descent along the relevant directions and dampens oscillations in irrelevant directions. To update the weights it uses the gradient of the current step as well as the accumulated gradients of previous steps, which helps us move faster towards convergence.

$V_{(t)}=\gamma V_{(t-1)} +\epsilon \nabla_{\theta} \frac{1}{m}\sum_i L(f(x^{(i)},\theta),y^{(i)})$

$\theta = \theta - V_{(t)}$

Common values of the momentum parameter $\gamma$ are 0.5 and 0.9.

Nesterov accelerated optimization is like a ball rolling down the hill that knows to slow down before the hill slopes up again. Nesterov accelerated gradient (NAG) is a way to give our momentum term this kind of prescience. We know that we will use our momentum term $\gamma V_{(t-1)}$ to move the parameters $\theta$. Computing $\theta - \gamma V_{(t-1)}$ thus gives us an approximation of the next position of the parameters, i.e. a rough idea of where they are going to be. We can now effectively look ahead by calculating the gradient not w.r.t. our current parameters $\theta$ but w.r.t. this approximate future position of the parameters:

$V_{(t)}=\gamma V_{(t-1)} +\epsilon \nabla_{\theta} \frac{1}{m}\sum_i L(f(x^{(i)},\theta - \gamma V_{(t-1)}),y^{(i)})$

$\theta = \theta - V_{(t)}$
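
A minimal plain-Java sketch of the momentum update, with the velocity persisted between steps (again an illustration, not Deeplearning4j's implementation). For Nesterov momentum the only change is that grad must be the gradient evaluated at the look-ahead point theta - gamma * velocity rather than at theta:

// One momentum update step.
// velocity: V, persisted across calls; gamma: momentum coefficient (e.g. 0.9); lr: learning rate.
void momentumStep(double[] theta, double[] grad, double[] velocity, double gamma, double lr) {
    for (int i = 0; i < theta.length; i++) {
        velocity[i] = gamma * velocity[i] + lr * grad[i];   // V_t = gamma*V_{t-1} + epsilon*g
        theta[i] -= velocity[i];                            // theta = theta - V_t
    }
}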

3.) AdaGrad

AdaGrad is an adaptive learning rate method: it adapts the learning rate to the parameters, using a different effective learning rate for every parameter at each time step, based on the past gradients computed for that parameter. It does this by accumulating the squared gradients seen so far and dividing the learning rate by the square root of this sum:

$g = \frac{1}{m}\nabla_{\theta}\sum_i L(f(x^{(i)},\theta),y^{(i)})$

$s = s + g^{T}g$

$\theta = \theta - \frac{\epsilon_{k}*g}{\sqrt{s + \epsilon}}$

As a result, parameters that receive large gradients have their effective learning rate reduced, while parameters that receive small gradients have their effective learning rate increased. The net effect is greater progress in the more gently sloped directions of parameter space and more cautious updates in the presence of large gradients. For this reason, AdaGrad is well-suited for dealing with sparse data.
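
A minimal plain-Java sketch of the AdaGrad update, applied element-wise per parameter (an illustration of the rule above, not Deeplearning4j's implementation):

// One AdaGrad update step.
// sumSq: running sum of squared gradients s, persisted across calls; eps: small smoothing term.
void adagradStep(double[] theta, double[] grad, double[] sumSq, double lr, double eps) {
    for (int i = 0; i < theta.length; i++) {
        sumSq[i] += grad[i] * grad[i];                          // s = s + g*g
        theta[i] -= lr * grad[i] / Math.sqrt(sumSq[i] + eps);   // theta = theta - lr*g/sqrt(s+eps)
    }
}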

4.) AdaDelta

The main disadvantage of AdaGrad is that its effective learning rate is always decreasing. This happens because the squared gradients accumulate in the denominator: every added term is positive, so the sum keeps growing during training. This in turn causes the effective learning rate to shrink and eventually become so small that the model can barely learn anything new, giving very slow convergence and very long training times.

This problem of the decaying learning rate is addressed by an algorithm called AdaDelta. Instead of accumulating all previous squared gradients, AdaDelta restricts the window of accumulated past gradients to some fixed size $w$ (implemented as an exponentially decaying average). With AdaDelta we don't even need to set a default learning rate.

$\theta_{(t+1)} = \theta_{(t)} + \Delta \theta_{(t)}$

$\Delta\theta_{(t)} = -\frac{RMS[\Delta\theta]_{(t-1)}}{RMS[g]_{(t)}} \cdot g_{(t)}$
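
A minimal plain-Java sketch of the AdaDelta update, where both RMS terms are maintained as exponentially decaying averages (rho is the decay rate, commonly around 0.95; the element-wise formulation is an assumption of this sketch):

// One AdaDelta update step.
// accGrad: decaying average of squared gradients; accUpdate: decaying average of squared updates.
void adadeltaStep(double[] theta, double[] grad, double[] accGrad, double[] accUpdate,
                  double rho, double eps) {
    for (int i = 0; i < theta.length; i++) {
        accGrad[i] = rho * accGrad[i] + (1 - rho) * grad[i] * grad[i];
        double delta = -Math.sqrt(accUpdate[i] + eps) / Math.sqrt(accGrad[i] + eps) * grad[i];
        accUpdate[i] = rho * accUpdate[i] + (1 - rho) * delta * delta;
        theta[i] += delta;   // theta_{t+1} = theta_t + delta_theta_t
    }
}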

5.) RMSProp

RMSProp stands for Root Mean Square Propagation. It tries to resolve AdaGrad's radically diminishing learning rates by using an exponentially weighted moving average of the squared gradients. This moving average weighs the recent past more heavily than the distant past, so the magnitude of recent gradients is used to normalize the current gradient.

RMSProp divides the learning rate by the square root of this exponentially decaying average of squared gradients:

$g = \frac{1}{m}\nabla_{\theta}\sum_i L(f(x^{(i)},\theta),y^{(i)})$

$s = (decay\_rate*s) + (1-decay\_rate)g^{T}g$

$\theta = \theta - \frac{\epsilon_{k}*g}{\sqrt{s + \epsilon}}$
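
A minimal plain-Java sketch of the RMSProp update (an illustration of the rule above; decayRate is typically around 0.9):

// One RMSProp update step.
// meanSq: exponentially decaying average of squared gradients s, persisted across calls.
void rmspropStep(double[] theta, double[] grad, double[] meanSq,
                 double decayRate, double lr, double eps) {
    for (int i = 0; i < theta.length; i++) {
        meanSq[i] = decayRate * meanSq[i] + (1 - decayRate) * grad[i] * grad[i];
        theta[i] -= lr * grad[i] / Math.sqrt(meanSq[i] + eps);
    }
}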

6.) Adam

Adam stands for Adaptive Moment Estimation. It is another method that computes adaptive learning rates for each parameter, and it also avoids the radically diminishing learning rates of AdaGrad. It can be viewed as combining the advantages of AdaGrad, which works well on sparse gradients, and RMSProp, which works well in online and nonstationary settings. In addition to an exponentially decaying average of past squared gradients (as in RMSProp), Adam also keeps an exponentially decaying average of past gradients $m$, similar to momentum.

Adam optimizer is one of the most popular gradient descent optimization algorithms.

$g = \frac{1}{m}\nabla_{\theta}\sum_i L(f(x^{(i)},\theta),y^{(i)})$

$m = \beta_1m + (1-\beta_1)g$

$s = \beta_2s + (1-\beta_2)g^{T}g$

$\theta = \theta - \frac{\epsilon_{k}*m}{\sqrt{s + \epsilon}}$

The recommended values are $\beta_1 = 0.9$, $\beta_2 = 0.999$, and $\epsilon = 10^{-8}$.
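
A minimal plain-Java sketch of the Adam update. The full algorithm also bias-corrects $m$ and $s$ by dividing by $1-\beta_1^t$ and $1-\beta_2^t$, which mainly matters in the first iterations; the sketch includes that correction for completeness:

// One Adam update step at iteration t (t starts at 1).
// m: decaying average of gradients; s: decaying average of squared gradients.
void adamStep(double[] theta, double[] grad, double[] m, double[] s,
              int t, double beta1, double beta2, double lr, double eps) {
    for (int i = 0; i < theta.length; i++) {
        m[i] = beta1 * m[i] + (1 - beta1) * grad[i];
        s[i] = beta2 * s[i] + (1 - beta2) * grad[i] * grad[i];
        double mHat = m[i] / (1 - Math.pow(beta1, t));   // bias correction
        double sHat = s[i] / (1 - Math.pow(beta2, t));
        theta[i] -= lr * mHat / (Math.sqrt(sHat) + eps);
    }
}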

To compare the optimizers, a dense neural net is trained on the MNIST dataset using five different updaters: SGD, Nesterov momentum, AdaGrad, RMSProp, and Adam. The plot of training score versus iteration is produced by the cell below.


In [5]:
import static com.xlson.groovycsv.CsvParser.parseCsv

def file_sgd = new File("assets/training_scores/updaters/sgd.csv");
def csv_content_sgd = file_sgd.getText('utf-8')
def data_iterator_sgd = parseCsv(csv_content_sgd, separator: '	', readFirstLine: false)
def iter_sgd = [];
def scores_sgd = [];
def i_sgd = 1;
for(line_sgd in data_iterator_sgd){iter_sgd.add(i_sgd); scores_sgd.add(line_sgd[1]);++i_sgd}

def file_nesterovs = new File("assets/training_scores/updaters/nesterovs.csv");
def csv_content_nesterovs = file_nesterovs.getText('utf-8')
def data_iterator_nesterovs = parseCsv(csv_content_nesterovs, separator: '	', readFirstLine: false)
def iter_nesterovs = [];
def scores_nesterovs = [];
def i_nesterovs = 1;
for(line_nesterovs in data_iterator_nesterovs){iter_nesterovs.add(i_nesterovs); scores_nesterovs.add(line_nesterovs[1]);++i_nesterovs}

def file_adagrad = new File("assets/training_scores/updaters/adagrad.csv");
def csv_content_adagrad = file_adagrad.getText('utf-8')
def data_iterator_adagrad = parseCsv(csv_content_adagrad, separator: '	', readFirstLine: false)
def iter_adagrad = [];
def scores_adagrad = [];
def i_adagrad = 1;
for(line_adagrad in data_iterator_adagrad){iter_adagrad.add(i_adagrad); scores_adagrad.add(line_adagrad[1]);++i_adagrad}

def file_rmsprop = new File("assets/training_scores/updaters/rmsprop.csv");
def csv_content_rmsprop = file_rmsprop.getText('utf-8')
def data_iterator_rmsprop = parseCsv(csv_content_rmsprop, separator: '	', readFirstLine: false)
def iter_rmsprop = [];
def scores_rmsprop = [];
def i_rmsprop = 1;
for(line_rmsprop in data_iterator_rmsprop){iter_rmsprop.add(i_rmsprop); scores_rmsprop.add(line_rmsprop[1]);++i_rmsprop}

def file_adam = new File("assets/training_scores/updaters/adam.csv");
def csv_content_adam = file_adam.getText('utf-8')
def data_iterator_adam = parseCsv(csv_content_adam, separator: '	', readFirstLine: false)
def iter_adam = [];
def scores_adam = [];
def i_adam = 1;
for(line_adam in data_iterator_adam){iter_adam.add(i_adam); scores_adam.add(line_adam[1]);++i_adam}

def plot = new Plot(title: "Training Curve", yLabel: "Model score", xLabel: "Iteration");
plot << new Line(x: iter_sgd, y: scores_sgd, displayName: "SGD", color: Color.blue, width: 1)
plot << new Line(x: iter_nesterovs, y: scores_nesterovs, displayName: "Nesterovs", color: Color.red, width: 1)
plot << new Line(x: iter_adagrad, y: scores_adagrad, displayName: "AdaGrad", color: Color.orange, width: 1)
plot << new Line(x: iter_rmsprop, y: scores_rmsprop, displayName: "RMSProp", color: Color.green, width: 1)
plot << new Line(x: iter_adam, y: scores_adam, displayName: "Adam", color: Color.black, width: 1)
OutputCell.HIDDEN

We can see from the plot that the Adam, RMSProp and Nesterov momentum optimizers produce the lowest training loss!

Convolutional networks

Convolutional neural networks are specialized models that are highly efficient at processing information that can be represented as measurements on a grid. This includes images, which are measurements of brightness on a two-dimensional grid; audio waveforms, which can be regarded as a one-dimensional grid across time; and three-dimensional grid data such as the 3-D scans used in medical imaging. For a convolutional network, we use 4-dimensional arrays (known as feature maps) to keep track of the shape of the image. A feature map is split into several channels, and each channel describes how a single type of feature appears across the entire image. The feature map has shape $m \times h \times w \times c$ (a small example follows the list below), where:

  • $m$ is the number of examples to process together in the same batch,
  • $h$ is the height of the image,
  • $w$ is the width of the image, and
  • $c$ is the number of channels.
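
For example, a mini-batch of 64 MNIST images in this layout is a $64 \times 28 \times 28 \times 1$ array. A small ND4J sketch of such an array (channels-last, following the text above; note that DL4J's convolution layers internally use a channels-first layout, so this is only meant to illustrate the shape):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;

// A batch of 64 single-channel 28x28 images, stored as [m, h, w, c].
int[] featureMapShape = [64, 28, 28, 1];
INDArray featureMap = Nd4j.zeros(featureMapShape);
System.out.println(Arrays.toString(featureMap.shape()));   // [64, 28, 28, 1]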

Now let's create a convolutional neural network and train it on the MNIST dataset.


In [2]:
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.*;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.util.Random;

int seed = 123;
double learningRate = 0.01;
int batchSize = 100;
int numEpochs = 1;

int height = 28;
int width = 28;
int channels = 1;
int numInput = height * width;
int numHidden = 1000;
int numOutput = 10;

//Prepare data for loading
File trainData = new File("./assets/mnist_png/training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, new Random(seed));
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // use parent directory name as the image label
ImageRecordReader trainRR = new ImageRecordReader(height, width, 1, labelMaker);
trainRR.initialize(trainSplit);
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, numOutput);
DataNormalization imageScaler = new ImagePreProcessingScaler();
imageScaler.fit(trainIter);
trainIter.setPreProcessor(imageScaler);


File testData = new File("./assets/mnist_png/testing");
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, new Random(seed));
ImageRecordReader testRR = new ImageRecordReader(height, width, 1, labelMaker);
testRR.initialize(testSplit);
DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, numOutput);
testIter.setPreProcessor(imageScaler);

//Build the neural network
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .updater(Updater.ADAM)
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new ConvolutionLayer.Builder(5, 5)
                .nIn(channels)
                .stride(1, 1)
                .nOut(20)
                .activation(Activation.IDENTITY)
                .build())
            .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            .layer(2, new ConvolutionLayer.Builder(5, 5)
                .stride(1, 1) // nIn need not be specified in later layers
                .nOut(50)
                .activation(Activation.IDENTITY)
                .build())
            .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
                .nOut(500)
                .build())
            .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(numOutput)
                .activation(Activation.SOFTMAX)
                .build())
            .setInputType(InputType.convolutionalFlat(height, width, channels)) 
            .build();


MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
CollectScoresIterationListener iterationListener = new CollectScoresIterationListener();
model.setListeners(iterationListener);

//Train the model and evaluate
for (int i = 0; i < numEpochs; i++) {
    model.fit(trainIter);
    System.out.println("********Evaluation Stats*********");
    Evaluation eval = model.evaluate(testIter);
    System.out.println(eval.stats());

    trainIter.reset();
    testIter.reset();
}

iterationListener.exportScores(new File("assets/training_scores/cnn.csv"));

System.out.println("********Example finished*********");


********Evaluation Stats*********

Examples labeled as 0 classified by model as 0: 967 times
Examples labeled as 0 classified by model as 2: 1 times
Examples labeled as 0 classified by model as 5: 2 times
Examples labeled as 0 classified by model as 7: 2 times
Examples labeled as 0 classified by model as 8: 2 times
Examples labeled as 0 classified by model as 9: 6 times
Examples labeled as 1 classified by model as 1: 1132 times
Examples labeled as 1 classified by model as 3: 1 times
Examples labeled as 1 classified by model as 5: 1 times
Examples labeled as 1 classified by model as 6: 1 times
Examples labeled as 2 classified by model as 0: 1 times
Examples labeled as 2 classified by model as 1: 2 times
Examples labeled as 2 classified by model as 2: 1015 times
Examples labeled as 2 classified by model as 3: 3 times
Examples labeled as 2 classified by model as 4: 1 times
Examples labeled as 2 classified by model as 5: 1 times
Examples labeled as 2 classified by model as 7: 7 times
Examples labeled as 2 classified by model as 8: 1 times
Examples labeled as 2 classified by model as 9: 1 times
Examples labeled as 3 classified by model as 2: 1 times
Examples labeled as 3 classified by model as 3: 990 times
Examples labeled as 3 classified by model as 5: 8 times
Examples labeled as 3 classified by model as 7: 2 times
Examples labeled as 3 classified by model as 9: 9 times
Examples labeled as 4 classified by model as 4: 975 times
Examples labeled as 4 classified by model as 9: 7 times
Examples labeled as 5 classified by model as 0: 1 times
Examples labeled as 5 classified by model as 2: 1 times
Examples labeled as 5 classified by model as 3: 4 times
Examples labeled as 5 classified by model as 5: 883 times
Examples labeled as 5 classified by model as 7: 1 times
Examples labeled as 5 classified by model as 9: 2 times
Examples labeled as 6 classified by model as 0: 10 times
Examples labeled as 6 classified by model as 1: 3 times
Examples labeled as 6 classified by model as 2: 1 times
Examples labeled as 6 classified by model as 3: 1 times
Examples labeled as 6 classified by model as 4: 16 times
Examples labeled as 6 classified by model as 5: 27 times
Examples labeled as 6 classified by model as 6: 899 times
Examples labeled as 6 classified by model as 8: 1 times
Examples labeled as 7 classified by model as 1: 4 times
Examples labeled as 7 classified by model as 2: 5 times
Examples labeled as 7 classified by model as 3: 4 times
Examples labeled as 7 classified by model as 7: 1012 times
Examples labeled as 7 classified by model as 9: 3 times
Examples labeled as 8 classified by model as 0: 4 times
Examples labeled as 8 classified by model as 2: 5 times
Examples labeled as 8 classified by model as 3: 2 times
Examples labeled as 8 classified by model as 4: 4 times
Examples labeled as 8 classified by model as 5: 8 times
Examples labeled as 8 classified by model as 6: 1 times
Examples labeled as 8 classified by model as 7: 4 times
Examples labeled as 8 classified by model as 8: 919 times
Examples labeled as 8 classified by model as 9: 27 times
Examples labeled as 9 classified by model as 1: 2 times
Examples labeled as 9 classified by model as 2: 1 times
Examples labeled as 9 classified by model as 3: 1 times
Examples labeled as 9 classified by model as 4: 6 times
Examples labeled as 9 classified by model as 5: 2 times
Examples labeled as 9 classified by model as 7: 4 times
Examples labeled as 9 classified by model as 9: 993 times


==========================Scores========================================
 # of classes:    10
 Accuracy:        0.9785
 Precision:       0.9786
 Recall:          0.9781
 F1 Score:        0.9781
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
========================================================================
********Example finished*********
Out[2]:
null

In [7]:
import static com.xlson.groovycsv.CsvParser.parseCsv

def file = new File("assets/training_scores/cnn.csv");
def csv_content = file.getText('utf-8')
def data_iterator = parseCsv(csv_content, separator: '	', readFirstLine: false)
def iter = [];
def scores = [];
def i = 1;
for(line in data_iterator){iter.add(i); scores.add(line[1]);++i}

def plot = new Plot(title: "Training Curve", yLabel: "Model score", xLabel: "Iteration");
plot << new Line(x: iter, y: scores)
OutputCell.HIDDEN

We can now delete the MNIST dataset files as they are no longer required.


In [4]:
import aima.notebooks.deeplearning.DataUtils;
import java.io.File;

String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";
String BASE_PATH = "./assets";
String localFilePath = BASE_PATH + "/mnist_png.tar.gz";

File file = new File(localFilePath);
file.delete();
DataUtils dataUtils = new DataUtils();
dataUtils.deleteDir(BASE_PATH + "/mnist_png");


Out[4]:
null

Recurrent Neural Networks

Recurrent neural networks are networks that introduce the concept of time, i.e. they allow us to define the value of some variable $v$ at time step $t$ in terms of its values at previous time steps. For example, we can define an update rule $v_{(t)} = f(v_{(t-1)})$ using some function $f$ of our choice. These networks are particularly well suited for sequence-processing tasks, as they allow us to operate over sequences of vectors: sequences in the input, in the output, or, in the most general case, both. In the last few years, RNNs have been applied with remarkable success to a variety of problems such as speech recognition, language modeling, translation, image captioning, and more.
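
As a minimal illustration of this kind of recurrence (with an arbitrary, made-up choice of $f$ and parameters, purely for illustration and unrelated to the LSTM trained below):

// The basic recurrence v_t = f(v_{t-1}), here with f(v) = tanh(w*v + b).
double v = 0.0;                  // initial state v_0
double w = 0.5, b = 0.1;         // illustrative fixed parameters
for (int t = 1; t <= 5; t++) {
    v = Math.tanh(w * v + b);    // v_t depends on v_{t-1}
    System.out.println("v_" + t + " = " + v);
}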

Here, we'll apply an RNN to the simple problem of generating text character by character. So, let's start...


In [2]:
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.*;

int seed = 123;
int nHidden = 50;
int epochs = 15;

//Define a sentence to learn
//Add a dummy character at the beginning so that the RNN learns the complete sentence.
char[] LEARNSTRING = "*The quick brown fox jumps over a lazy dog.".toCharArray();

LinkedHashSet<Character> LEARNSTRING_CHARS = new LinkedHashSet<>();
for (char c : LEARNSTRING) LEARNSTRING_CHARS.add(c);
List<Character> LEARNSTRING_CHARS_LIST = new ArrayList<>();
LEARNSTRING_CHARS_LIST.addAll(LEARNSTRING_CHARS);


//Build the neural network
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .updater(Updater.ADAM)
        .weightInit(WeightInit.XAVIER)
        .list()
        .layer(0, new LSTM.Builder()
                .nIn(LEARNSTRING_CHARS.size())
                .nOut(nHidden)
                .activation(Activation.TANH)
                .build())
        .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .activation(Activation.SOFTMAX)
                .nIn(nHidden)
                .nOut(LEARNSTRING_CHARS.size())
                .build())
        .build();

MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
CollectScoresIterationListener iterationListener = new CollectScoresIterationListener();
model.setListeners(iterationListener);

//Create our training data
int[] shape = [1, LEARNSTRING_CHARS_LIST.size(), LEARNSTRING.length]
INDArray input = Nd4j.zeros(shape);
INDArray labels = Nd4j.zeros(shape);

int pos = 0;
for (char currChar : LEARNSTRING) {
    char nextChar = LEARNSTRING[(pos + 1) % (LEARNSTRING.length)]; //When currChar is the last, take the first character as nextChar.
    // Input neuron for current character is 1 at "pos"
    int[] inputArr = [0, LEARNSTRING_CHARS_LIST.indexOf(currChar), pos];
    input.putScalar(inputArr, 1);

    // Output neuron for next character is 1 at "pos"
    int[] labelArr = [0, LEARNSTRING_CHARS_LIST.indexOf(nextChar), pos];
    labels.putScalar(labelArr, 1);
    pos++;
}

DataSet trainingData = new DataSet(input, labels);

//Train the model and evaluate
for (int i = 0; i < epochs; i++) {
    model.fit(trainingData);
    model.rnnClearPreviousState();

    System.out.print("Epoch " + i + " completed. Sample:\t");
    //Evaluate
    //Put the first character into RNN as an initialisation
    int[] testShape = [1, LEARNSTRING_CHARS_LIST.size(), 1]
    INDArray testInit = Nd4j.zeros(testShape);
    testInit.putScalar(LEARNSTRING_CHARS_LIST.indexOf(LEARNSTRING[0]), 1);

    INDArray output = model.rnnTimeStep(testInit);
    //output now contains the scores for each character at this time step; the index of the highest-valued neuron is the character the model thinks should come next
    //now the model should guess (LEARNSTRING.length - 1) more characters...

    for (int j = 0; j < LEARNSTRING.length - 1; j++) {

        //First let's process the last output of the model.
        int sampledCharacterIndex = Nd4j.getExecutioner().exec(new IMax(output, null, 1),1).getAt(0);
        System.out.print(LEARNSTRING_CHARS_LIST.get(sampledCharacterIndex));

        //Use the last output as next input
        int[] nextInputShape = [1, LEARNSTRING_CHARS_LIST.size(), 1];
        INDArray nextInput = Nd4j.zeros(nextInputShape);
        nextInput.putScalar(sampledCharacterIndex, 1);
        output = model.rnnTimeStep(nextInput);
    }
    System.out.println();
}

iterationListener.exportScores(new File("assets/training_scores/rnn.csv"));


Epoch 0 completed. Sample:	                                          
Epoch 1 completed. Sample:	                                          
Epoch 2 completed. Sample:	e   aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
Epoch 3 completed. Sample:	eequuikk uukuukkuukuukuukuukkuukuukuukuukk
Epoch 4 completed. Sample:	The    orororororororororororororororororo
Epoch 5 completed. Sample:	The     o oooooooooooooooooooooooooooooooo
Epoch 6 completed. Sample:	Thhhq       o ogg.g.*jupssooo ogg.jmpssooo
Epoch 7 completed. Sample:	Theqqucccc    o o o o o o o o o o o o o o 
Epoch 8 completed. Sample:	Thequiccck  o o o o o o o o o o o o o o o 
Epoch 9 completed. Sample:	The quick brove o ox jumps ove o o ox jump
Epoch 10 completed. Sample:	The quick br o  ox jumps ove a ove a ove a
Epoch 11 completed. Sample:	The quick br a lazy dog.***Tee  uuick br a
Epoch 12 completed. Sample:	The quick brown fox jumps over a lazy dog.
Epoch 13 completed. Sample:	The quick brown fox jumps over a lazy dog.
Epoch 14 completed. Sample:	The quick brown fox jumps over a lazy dog.
Out[2]:
null

In [4]:
import static com.xlson.groovycsv.CsvParser.parseCsv

def file = new File("assets/training_scores/rnn.csv");
def csv_content = file.getText('utf-8')
def data_iterator = parseCsv(csv_content, separator: '	', readFirstLine: false)
def iter = [];
def scores = [];
def i = 1;
for(line in data_iterator){iter.add(i); scores.add(line[1]);++i}

def plot = new Plot(title: "Training Curve", yLabel: "Model score", xLabel: "Iteration");
plot << new Line(x: iter, y: scores)
OutputCell.HIDDEN

TensorFlow Java API

TensorFlow is an open source library for dataflow programming. In this notebook, we'll go through the basics of TensorFlow and how to use it in Java. Please note that the TensorFlow Java API does not have feature parity with the Python API; hence, the Java API is most suitable for inference with pre-trained models and for training pre-defined models from a single Java process. Later in the notebook, we'll cover the possible use cases for the TensorFlow Java API.

TensorFlow Graphs and Sessions

TensorFlow computation revolves around two fundamental concepts: the Graph and the Session.

Computations are represented as graphs in TensorFlow. A TensorFlow computational graph is typically a directed acyclic graph of operations and data. It consists of two kinds of elements:

  • Tensor: These are the core units of data in TensorFlow. They are represented as the edges in a computational graph, depicting the flow of data through the graph. A tensor can have a shape with any number of dimensions; the number of dimensions is usually referred to as its rank. So a scalar is a rank 0 tensor, a vector is a rank 1 tensor, a matrix is a rank 2 tensor, and so on.
  • Operation: These are the nodes in a computational graph. They represent the wide variety of computations that can happen on the tensors feeding into them, and they usually produce new tensors that emanate from the operation in the graph.

Now, a TensorFlow graph is merely a schematic of the computation: it doesn't compute anything and it doesn't hold any values; it just defines the operations that you specified in your code. Such a graph must be run inside a TensorFlow session for the tensors in the graph to be evaluated.

Let's build a graph and run it in a session using the TensorFlow Java API. More precisely, we'll use the TensorFlow Java API to evaluate the function represented by the equation $y = w*x + b$, where the constants $w$ and $b$ have the values 3 and 2 respectively. We'll begin by adding the required dependencies.


In [6]:
%%classpath add mvn
org.tensorflow tensorflow 1.12.0
org.tensorflow proto 1.12.0
com.google.guava guava 23.6-jre



In [10]:
import org.tensorflow.*;


//Creating the graph
Graph graph = new Graph();

//Defining constants
Operation w = graph.opBuilder("Const", "w")
        .setAttr("dtype", DataType.INT32)
        .setAttr("value", Tensor.create(3,Integer.class))
        .build();

Operation b = graph.opBuilder("Const", "b")
        .setAttr("dtype", DataType.INT32)
        .setAttr("value", Tensor.create(2, Integer.class))
        .build();

//Defining placeholders
Operation x = graph.opBuilder("Placeholder", "input")
        .setAttr("dtype", DataType.INT32)
        .build();

//Defining functions
Operation wx = graph.opBuilder("Mul", "wx")
        .addInput(w.output(0))
        .addInput(x.output(0))
        .build();

Operation y = graph.opBuilder("Add", "y")
        .addInput(wx.output(0))
        .addInput(b.output(0))
        .build();

//Running a session
Session session = new Session(graph);
Tensor<Integer> tensor = session.runner().fetch("y")
        .feed("input", Tensor.create(4, Integer.class))
        .run().get(0).expect(Integer.class);

//Expected ans: (3*4)+2 = 14
System.out.println(tensor.intValue());


14
Out[10]:
null

Now let's train a model using the TensorFlow Java API. As mentioned earlier, the Java API is suitable for training pre-defined models, so we'll use the file graph.pb, which is generated by executing create_graph.py in Python. The graph in graph.pb represents a simple linear model $y = w*x + b$. The training data is generated from $y = 3.0*x + 2.0$, so over time the model should learn, with the value of $w$ converging to 3.0 and $b$ to 2.0.


In [11]:
import org.tensorflow.*;

import java.nio.file.*;
import java.util.*;

final byte[] graphDef = Files.readAllBytes(Paths.get("./assets/training/graph.pb"));

try {
    Graph graph = new Graph();
     Session sess = new Session(graph); 
    graph.importGraphDef(graphDef);

    sess.runner().addTarget("init").run();

    System.out.print("Starting from       : ");
    printVariables(sess);

    // Train a bunch of times.
    final Random r = new Random();
    final int NUM_EXAMPLES = 500;
    for (int i = 1; i <= 5; i++) {
        for (int n = 0; n < NUM_EXAMPLES; n++) {
            Float inp = r.nextFloat();
            try {
                Tensor<Float> input = Tensors.create((float)inp);
                 Tensor<Float> target = Tensors.create((float)(3 * inp + 2));
                sess.runner().feed("input", input).feed("target", target).addTarget("train").run();
            } catch (Exception e){
                e.printStackTrace();
            }
        }
        System.out.printf("After %5d examples: ", i * NUM_EXAMPLES);
        printVariables(sess);
    }

    // Example of "inference" in the same graph:
    try {
        Tensor<Float> input = Tensors.create(1.0f);
         Tensor<Float> output =
                 sess.runner().feed("input", input).fetch("output").run().get(0).expect(Float.class);
        System.out.printf(
                "For input %f, produced %f (ideally would produce 3*%f + 2)\n",
                input.floatValue(), output.floatValue(), input.floatValue());
    } catch (Exception e){
        e.printStackTrace();
    }
} catch (Exception e){
    e.printStackTrace();
}

private void printVariables(Session sess) {
    List<Tensor<?>> values = sess.runner().fetch("W/read").fetch("b/read").run();
    System.out.printf("W = %f\tb = %f\n", values.get(0).floatValue(), values.get(1).floatValue());
    for (Tensor<?> t : values) {
        t.close();
    }
}


Starting from       : W = 5.000000	b = 3.000000
After   500 examples: W = 3.582906	b = 1.685713
After  1000 examples: W = 3.307192	b = 1.835320
After  1500 examples: W = 3.160639	b = 1.911917
After  2000 examples: W = 3.083861	b = 1.955856
After  2500 examples: W = 3.041650	b = 1.976595
For input 1.000000, produced 5.018245 (ideally would produce 3*1.000000 + 2)
Out[11]:
java.io.PrintStream@79528d10

So far, we've learned to perform basic operations using the TensorFlow Java API. But of course, TensorFlow is meant to run graphs much larger than this, and the tensors it deals with in real-world models are much larger in size and rank. It is in such machine learning models that TensorFlow finds its real use.

Working with the core TensorFlow API can become cumbersome as the size of the graph increases, so TensorFlow provides high-level APIs like Keras for working with complex models. Unfortunately, there is little to no official support for Keras in Java yet. However, we can use Python to define and train complex models, either directly in TensorFlow or using high-level APIs like Keras, and subsequently export the trained model and use it from Java through the TensorFlow Java API. This is particularly useful when we want machine-learning-enabled features in existing clients running on Java, or whenever we are interested in the output of a model but do not necessarily want to create and train that model in Java. This is where the TensorFlow Java API finds the bulk of its use.

Let's learn how this can be achieved through an example of image classification. We'll use the pre-trained inception5h model (developed by Google) to classify the image shown below.


In [12]:
import org.tensorflow.*;

import java.io.*;
import java.nio.file.*;
import java.util.*;

final List<String> labels = new ArrayList<>();
String line;
BufferedReader reader = new BufferedReader(new FileReader("./assets/image_classification/labels.txt"));
while ((line = reader.readLine()) != null) {
    labels.add(line);
}

final byte[] graphDef = Files.readAllBytes(Paths.get("./assets/image_classification/graph.pb"));
try{
    Graph graph = new Graph();
    Session session = new Session(graph);
    graph.importGraphDef(graphDef);
    
    String filename = "./assets/image_classification/chair.jpeg";
    float[] probabilities = null;
    byte[] bytes = Files.readAllBytes(Paths.get(filename));
    try {
        Tensor<String> input = Tensors.create(bytes);
        Tensor<Float> output =
                 session
                         .runner()
                         .feed("encoded_image_bytes", input)
                         .fetch("probabilities")
                         .run()
                         .get(0)
                         .expect(Float.class)
        if (probabilities == null) {
            probabilities = new float[(int) output.shape()[0]];
        }
        output.copyTo(probabilities);
        int label = argmax(probabilities);
        System.out.printf(
                "%-30s --> %-15s (%.2f%% likely)\n",
                filename, labels.get(label), probabilities[label] * 100.0);
    } catch(Exception e){
        e.printStackTrace();
    }
} catch(Exception e){
    e.printStackTrace();
}

private int argmax(float[] probabilities) {
    int best = 0;
    for (int i = 1; i < probabilities.length; ++i) {
        if (probabilities[i] > probabilities[best]) {
            best = i;
        }
    }
    return best;
}


./assets/image_classification/chair.jpeg --> studio couch    (94.77% likely)
Out[12]:
java.io.PrintStream@79528d10

Let's test this model on other images as well...


In [13]:
import org.tensorflow.*;

import java.io.*;
import java.nio.file.*;
import java.util.*;

final List<String> labels = new ArrayList<>();
String line;
BufferedReader reader = new BufferedReader(new FileReader("./assets/image_classification/labels.txt"));
while ((line = reader.readLine()) != null) {
    labels.add(line);
}

final byte[] graphDef = Files.readAllBytes(Paths.get("./assets/image_classification/graph.pb"));
try{
    Graph graph = new Graph();
    Session session = new Session(graph);
    graph.importGraphDef(graphDef);
    
    List<String> filename = new ArrayList<>();
    filename.add("./assets/image_classification/terrier.jpg");
    filename.add("./assets/image_classification/terrier2.jpg");
    filename.add("./assets/image_classification/porcupine.jpg");
    filename.add("./assets/image_classification/whale.jpg");
    
    for(int i = 0; i < filename.size(); ++i){
        float[] probabilities = null;
        byte[] bytes = Files.readAllBytes(Paths.get(filename.get(i)));
        try {
            Tensor<String> input = Tensors.create(bytes);
            Tensor<Float> output =
                     session
                             .runner()
                             .feed("encoded_image_bytes", input)
                             .fetch("probabilities")
                             .run()
                             .get(0)
                             .expect(Float.class)
            if (probabilities == null) {
                probabilities = new float[(int) output.shape()[0]];
            }
            output.copyTo(probabilities);
            int label = argmax(probabilities);
            System.out.printf(
                    "%-30s --> %-15s (%.2f%% likely)\n",
                    filename.get(i), labels.get(label), probabilities[label] * 100.0);
        } catch(Exception e){
            e.printStackTrace();
        }
    }
} catch(Exception e){
    e.printStackTrace();
}

private int argmax(float[] probabilities) {
    int best = 0;
    for (int i = 1; i < probabilities.length; ++i) {
        if (probabilities[i] > probabilities[best]) {
            best = i;
        }
    }
    return best;
}


./assets/image_classification/terrier.jpg --> Australian terrier (67.63% likely)
./assets/image_classification/terrier2.jpg --> Tibetan terrier (44.09% likely)
./assets/image_classification/porcupine.jpg --> porcupine       (85.03% likely)
./assets/image_classification/whale.jpg --> killer whale    (41.59% likely)
Out[13]:
null

Let's see another example of using a pre-trained model. Here, we'll use the ssd_inception_v2_coco model for object detection.


In [6]:
import org.tensorflow.*;
import org.tensorflow.framework.*;
import org.tensorflow.types.UInt8;

import javax.imageio.ImageIO;
import javax.swing.*;
import java.awt.*;
import java.awt.image.*;
import java.io.*;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

final List<String> labels = new ArrayList<>();
String line;
BufferedReader reader = new BufferedReader(new FileReader("./assets/object_detection/object_label.txt"));
while ((line = reader.readLine()) != null) {
    labels.add(line);
}

try{
    // Load the pre-trained SSD Inception v2 (COCO) SavedModel and print its input/output signature.
    SavedModelBundle model = SavedModelBundle.load("./assets/object_detection/ssd_inception_v2_coco_2017_11_17/saved_model", "serve");
    printSignature(model);
    final String filename = "./assets/object_detection/test.jpg";
    List<Tensor<?>> outputs = null;
    try {
        // Convert the image to a uint8 tensor and run the detector.
        Tensor<UInt8> input = makeImageTensor(filename);
        outputs =
                model
                        .session()
                        .runner()
                        .feed("image_tensor", input)
                        .fetch("detection_scores")
                        .fetch("detection_classes")
                        .fetch("detection_boxes")
                        .run();
    }  catch(Exception e){
        e.printStackTrace();
    } 
    
    try {
        // Each output has shape [1, maxObjects]; take batch element 0.
        Tensor<Float> scoresT = outputs.get(0).expect(Float.class);
        Tensor<Float> classesT = outputs.get(1).expect(Float.class);
        Tensor<Float> boxesT = outputs.get(2).expect(Float.class);
        int maxObjects = (int) scoresT.shape()[1];
        float[] scores = scoresT.copyTo(new float[1][maxObjects])[0];
        float[] classes = classesT.copyTo(new float[1][maxObjects])[0];
        float[][] boxes = boxesT.copyTo(new float[1][maxObjects][4])[0];
        
        System.out.printf("* %s\n", filename);
        
        BufferedImage myPicture = ImageIO.read(new File(filename));
        Graphics2D g = (Graphics2D) myPicture.getGraphics();
        g.setStroke(new BasicStroke(3));
        g.setFont(new Font(Font.MONOSPACED, Font.BOLD, 16));
        List<Color> colors = new ArrayList<>();
        colors.add(java.awt.Color.RED);
        colors.add(java.awt.Color.ORANGE);
        colors.add(java.awt.Color.BLUE);
        
        
        boolean foundSomething = false;
        // Print all objects whose score is at least 0.5.
        for (int i = 0; i < scores.length; ++i) {
            if (scores[i] < 0.5) {
                continue;
            }
            foundSomething = true;
            g.setColor(colors.get(i % colors.size()));
            // boxes[i] is normalized [ymin, xmin, ymax, xmax]; scale to pixel coordinates.
            g.drawRect((int) (boxes[i][1] * myPicture.getWidth()), (int) (boxes[i][0] * myPicture.getHeight()), (int) ((boxes[i][3] - boxes[i][1]) * myPicture.getWidth()), (int) ((boxes[i][2] - boxes[i][0]) * myPicture.getHeight()));
            g.drawString(labels.get((int) classes[i]) + " " + String.format("%.2f", (scores[i] * 100)) + "%", (float)(boxes[i][1] * myPicture.getWidth()), (float)(boxes[i][0] * myPicture.getHeight()));
            System.out.printf("\tFound %-20s (score: %.4f)\n", labels.get((int) classes[i]), scores[i]);
        }
        if (!foundSomething) {
            System.out.println("No objects detected with a high enough score.");
        }
        
        // The Swing frame stays hidden; the annotated image is written to disk instead.
        JLabel picLabel = new JLabel(new ImageIcon(myPicture));
        JPanel jPanel = new JPanel();
        jPanel.add(picLabel);
        JFrame f = new JFrame();
        f.setSize(new Dimension(myPicture.getWidth(), myPicture.getHeight()));
        f.add(jPanel);
        f.setVisible(false);
        ImageIO.write(myPicture, "jpg", new File("./assets/object_detection/output.jpg"));
    } catch(Exception e){
        e.printStackTrace();
    }
} catch(Exception e){
    e.printStackTrace();
}
  
private static void printSignature(SavedModelBundle model) throws Exception {
    MetaGraphDef m = MetaGraphDef.parseFrom(model.metaGraphDef());
    SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
    int numInputs = sig.getInputsCount();
    int i = 1;
    System.out.println("MODEL SIGNATURE");
    System.out.println("Inputs:");
    for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
        TensorInfo t = entry.getValue();
        System.out.printf(
                "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
    }
    int numOutputs = sig.getOutputsCount();
    i = 1;
    System.out.println("Outputs:");
    for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
        TensorInfo t = entry.getValue();
        System.out.printf(
                "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
    }
    System.out.println("-----------------------------------------------");
}

private static void bgr2rgb(byte[] data) {
    for (int i = 0; i < data.length; i += 3) {
        byte tmp = data[i];
        data[i] = data[i + 2];
        data[i + 2] = tmp;
    }
}

private static Tensor<UInt8> makeImageTensor(String filename) throws IOException {
    BufferedImage img = ImageIO.read(new File(filename));
    if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
        throw new IOException(
                String.format(
                        "Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust",
                        img.getType(), filename));
    }
    byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
    // ImageIO.read seems to produce BGR-encoded images, but the model expects RGB.
    bgr2rgb(data);
    final long BATCH_SIZE = 1;
    final long CHANNELS = 3;
    long[] shape = [BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS];
    return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
}


MODEL SIGNATURE
Inputs:
1 of 1: inputs               (Node name in graph: image_tensor:0      , type: DT_UINT8)
Outputs:
1 of 4: detection_classes    (Node name in graph: detection_classes:0 , type: DT_FLOAT)
2 of 4: num_detections       (Node name in graph: num_detections:0    , type: DT_FLOAT)
3 of 4: detection_boxes      (Node name in graph: detection_boxes:0   , type: DT_FLOAT)
4 of 4: detection_scores     (Node name in graph: detection_scores:0  , type: DT_FLOAT)
-----------------------------------------------
* ./assets/object_detection/test.jpg
	Found person               (score: 0.9877)
	Found person               (score: 0.9687)
	Found dog                  (score: 0.9489)
	Found person               (score: 0.9348)
	Found person               (score: 0.8712)
Out[6]:
true
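
The signature above also lists a num_detections output that the cell does not fetch; it reports how many of the returned rows are actual detections. As a sketch (assuming the same session, input tensor, and post-processing code as above), it could be fetched alongside the other three outputs:

// Sketch: additionally fetch num_detections and use it to bound the loop,
// instead of scanning all maxObjects rows with only the score threshold.
List<Tensor<?>> outputs =
        model
                .session()
                .runner()
                .feed("image_tensor", input)
                .fetch("detection_scores")
                .fetch("detection_classes")
                .fetch("detection_boxes")
                .fetch("num_detections")
                .run();

// num_detections has shape [1] and is returned as a float.
float[] num = outputs.get(3).expect(Float.class).copyTo(new float[1]);
int numDetections = (int) num[0];
for (int i = 0; i < numDetections; ++i) {
    // scores[i], classes[i] and boxes[i] are valid rows here; the 0.5
    // score threshold from the cell above can still be applied on top.
}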

The input and output images are as follows:

Let's test this model on other images...


In [7]:
import org.tensorflow.*;
import org.tensorflow.framework.*;
import org.tensorflow.types.UInt8;

import javax.imageio.ImageIO;
import javax.swing.*;
import java.awt.*;
import java.awt.image.*;
import java.io.*;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

final List<String> labels = new ArrayList<>();
String line;
BufferedReader reader = new BufferedReader(new FileReader("./assets/object_detection/object_label.txt"));
while ((line = reader.readLine()) != null) {
    labels.add(line);
}

try{
    SavedModelBundle model = SavedModelBundle.load("./assets/object_detection/ssd_inception_v2_coco_2017_11_17/saved_model", "serve");
    printSignature(model);
    final String filename = "./assets/object_detection/test2.jpeg";
    List<Tensor<?>> outputs = null;
    try {
        Tensor<UInt8> input = makeImageTensor(filename);
        outputs =
                model
                        .session()
                        .runner()
                        .feed("image_tensor", input)
                        .fetch("detection_scores")
                        .fetch("detection_classes")
                        .fetch("detection_boxes")
                        .run();
    }  catch(Exception e){
        e.printStackTrace();
    } 
    
    try {
        Tensor<Float> scoresT = outputs.get(0).expect(Float.class);
        Tensor<Float> classesT = outputs.get(1).expect(Float.class);
        Tensor<Float> boxesT = outputs.get(2).expect(Float.class);
        int maxObjects = (int) scoresT.shape()[1];
        float[] scores = scoresT.copyTo(new float[1][maxObjects])[0];
        float[] classes = classesT.copyTo(new float[1][maxObjects])[0];
        float[][] boxes = boxesT.copyTo(new float[1][maxObjects][4])[0];
        
        System.out.printf("* %s\n", filename);
        
        BufferedImage myPicture = ImageIO.read(new File(filename));
        Graphics2D g = (Graphics2D) myPicture.getGraphics();
        g.setStroke(new BasicStroke(3));
        g.setFont(new Font(Font.MONOSPACED, Font.BOLD, 16));
        List<Color> colors = new ArrayList<>();
        colors.add(java.awt.Color.RED);
        colors.add(java.awt.Color.ORANGE);
        colors.add(java.awt.Color.BLUE);
        
        
        boolean foundSomething = false;
        // Print all objects whose score is at least 0.5.
        for (int i = 0; i < scores.length; ++i) {
            if (scores[i] < 0.5) {
                continue;
            }
            foundSomething = true;
            g.setColor(colors.get(i % colors.size()));
            g.drawRect((int) (boxes[i][1] * myPicture.getWidth()), (int) (boxes[i][0] * myPicture.getHeight()), (int) ((boxes[i][3] - boxes[i][1]) * myPicture.getWidth()), (int) ((boxes[i][2] - boxes[i][0]) * myPicture.getHeight()));
            g.drawString(labels.get((int) classes[i]) + " " + String.format("%.2f", (scores[i] * 100)) + "%", (float)(boxes[i][1] * myPicture.getWidth()), (float)(boxes[i][0] * myPicture.getHeight()));
            System.out.printf("\tFound %-20s (score: %.4f)\n", labels.get((int) classes[i]), scores[i]);
        }
        if (!foundSomething) {
            System.out.println("No objects detected with a high enough score.");
        }
        
        JLabel picLabel = new JLabel(new ImageIcon(myPicture));
        JPanel jPanel = new JPanel();
        jPanel.add(picLabel);
        JFrame f = new JFrame();
        f.setSize(new Dimension(myPicture.getWidth(), myPicture.getHeight()));
        f.add(jPanel);
        f.setVisible(false);
        ImageIO.write(myPicture, "jpg", new File("./assets/object_detection/output2.jpg"));
    } catch(Exception e){
        e.printStackTrace();
    }
} catch(Exception e){
    e.printStackTrace();
}
  
private static void printSignature(SavedModelBundle model) throws Exception {
    MetaGraphDef m = MetaGraphDef.parseFrom(model.metaGraphDef());
    SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
    int numInputs = sig.getInputsCount();
    int i = 1;
    System.out.println("MODEL SIGNATURE");
    System.out.println("Inputs:");
    for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
        TensorInfo t = entry.getValue();
        System.out.printf(
                "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                i++, numInputs, entry.getKey(), t.getName(), t.getDtype());
    }
    int numOutputs = sig.getOutputsCount();
    i = 1;
    System.out.println("Outputs:");
    for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
        TensorInfo t = entry.getValue();
        System.out.printf(
                "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
                i++, numOutputs, entry.getKey(), t.getName(), t.getDtype());
    }
    System.out.println("-----------------------------------------------");
}

private static void bgr2rgb(byte[] data) {
    for (int i = 0; i < data.length; i += 3) {
        byte tmp = data[i];
        data[i] = data[i + 2];
        data[i + 2] = tmp;
    }
}

private static Tensor<UInt8> makeImageTensor(String filename) throws IOException {
    BufferedImage img = ImageIO.read(new File(filename));
    if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
        throw new IOException(
                String.format(
                        "Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust",
                        img.getType(), filename));
    }
    byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
    // ImageIO.read seems to produce BGR-encoded images, but the model expects RGB.
    bgr2rgb(data);
    final long BATCH_SIZE = 1;
    final long CHANNELS = 3;
    long[] shape = [BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS];
    return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
}


MODEL SIGNATURE
Inputs:
1 of 1: inputs               (Node name in graph: image_tensor:0      , type: DT_UINT8)
Outputs:
1 of 4: detection_boxes      (Node name in graph: detection_boxes:0   , type: DT_FLOAT)
2 of 4: detection_scores     (Node name in graph: detection_scores:0  , type: DT_FLOAT)
3 of 4: detection_classes    (Node name in graph: detection_classes:0 , type: DT_FLOAT)
4 of 4: num_detections       (Node name in graph: num_detections:0    , type: DT_FLOAT)
-----------------------------------------------
* ./assets/object_detection/test2.jpeg
	Found bus                  (score: 0.9844)
	Found truck                (score: 0.8359)
	Found person               (score: 0.7873)
	Found person               (score: 0.7810)
	Found person               (score: 0.6311)
	Found person               (score: 0.6285)
	Found truck                (score: 0.5652)
	Found person               (score: 0.5507)
Out[7]:
true

The input and output images are shown below.
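
The two detection cells differ only in the input and output file names, so (as a refactoring sketch, not part of the notebook above) the per-image work could be moved into a helper that reuses a single loaded SavedModelBundle. The detect() name is hypothetical, and the sketch assumes the labels list and the makeImageTensor() helper defined earlier; the Graphics2D drawing is omitted for brevity:

// Hypothetical helper: run one already-loaded model on a single image and
// print the detections above a score threshold. Assumes labels and
// makeImageTensor() from the cells above.
private void detect(SavedModelBundle model, List<String> labels, String inputFile) {
    try {
        Tensor<UInt8> input = makeImageTensor(inputFile);
        List<Tensor<?>> outputs =
                model
                        .session()
                        .runner()
                        .feed("image_tensor", input)
                        .fetch("detection_scores")
                        .fetch("detection_classes")
                        .run();
        int maxObjects = (int) outputs.get(0).shape()[1];
        float[] scores = outputs.get(0).expect(Float.class).copyTo(new float[1][maxObjects])[0];
        float[] classes = outputs.get(1).expect(Float.class).copyTo(new float[1][maxObjects])[0];
        System.out.printf("* %s\n", inputFile);
        for (int i = 0; i < scores.length; ++i) {
            if (scores[i] >= 0.5) {
                System.out.printf("\tFound %-20s (score: %.4f)\n", labels.get((int) classes[i]), scores[i]);
            }
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
}

// Usage: load the model once, then call the helper per image.
// SavedModelBundle model = SavedModelBundle.load("./assets/object_detection/ssd_inception_v2_coco_2017_11_17/saved_model", "serve");
// detect(model, labels, "./assets/object_detection/test.jpg");
// detect(model, labels, "./assets/object_detection/test2.jpeg");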