sklearn-porter

Repository: https://github.com/nok/sklearn-porter

MLPClassifier

Documentation: sklearn.neural_network.MLPClassifier


In [1]:
import sys
sys.path.append('../../../../..')

Load data


In [2]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

iris_data = load_iris()
X = iris_data.data
y = iris_data.target

X = shuffle(X, random_state=0)
y = shuffle(y, random_state=0)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.4, random_state=5)

print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)


((90, 4), (90,))
((60, 4), (60,))

Train classifier


In [3]:
from sklearn.neural_network import MLPClassifier

clf = MLPClassifier(activation='relu', hidden_layer_sizes=50,
                    max_iter=500, alpha=1e-4, solver='sgd',
                    tol=1e-4, random_state=1, learning_rate_init=.1)
clf.fit(X_train, y_train)


Out[3]:
MLPClassifier(activation='relu', alpha=0.0001, batch_size='auto', beta_1=0.9,
       beta_2=0.999, early_stopping=False, epsilon=1e-08,
       hidden_layer_sizes=50, learning_rate='constant',
       learning_rate_init=0.1, max_iter=500, momentum=0.9,
       n_iter_no_change=10, nesterovs_momentum=True, power_t=0.5,
       random_state=1, shuffle=True, solver='sgd', tol=0.0001,
       validation_fraction=0.1, verbose=False, warm_start=False)

Transpile classifier


In [4]:
from sklearn_porter import Porter

porter = Porter(clf, language='java')
output = porter.export(export_data=True)

print(output)


import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
import com.google.gson.Gson;


class MLPClassifier {

    private enum Activation { IDENTITY, LOGISTIC, RELU, TANH, SOFTMAX }

    private class Classifier {
        private String hidden_activation;
        private Activation hidden;
        private String output_activation;
        private Activation output;
        private double[][] network;
        private double[][][] weights;
        private double[][] bias;
        private int[] layers;
    }

    private Classifier clf;

    public MLPClassifier(String file) throws FileNotFoundException {
        String jsonStr = new Scanner(new File(file)).useDelimiter("\\Z").next();
        this.clf = new Gson().fromJson(jsonStr, Classifier.class);
        this.clf.network = new double[this.clf.layers.length + 1][];
        for (int i = 0, l = this.clf.layers.length; i < l; i++) {
            this.clf.network[i + 1] = new double[this.clf.layers[i]];
        }
        this.clf.hidden = Activation.valueOf(this.clf.hidden_activation.toUpperCase());
        this.clf.output = Activation.valueOf(this.clf.output_activation.toUpperCase());
    }

    private double[] compute(Activation activation, double[] v) {
        switch (activation) {
            case LOGISTIC:
                for (int i = 0, l = v.length; i < l; i++) {
                    v[i] = 1. / (1. + Math.exp(-v[i]));
                }
                break;
            case RELU:
                for (int i = 0, l = v.length; i < l; i++) {
                    v[i] = Math.max(0, v[i]);
                }
                break;
            case TANH:
                for (int i = 0, l = v.length; i < l; i++) {
                    v[i] = Math.tanh(v[i]);
                }
                break;
            case SOFTMAX:
                double max = Double.NEGATIVE_INFINITY;
                for (double x : v) {
                    if (x > max) {
                        max = x;
                    }
                }
                for (int i = 0, l = v.length; i < l; i++) {
                    v[i] = Math.exp(v[i] - max);
                }
                double sum = 0.;
                for (double x : v) {
                    sum += x;
                }
                for (int i = 0, l = v.length; i < l; i++) {
                    v[i] /= sum;
                }
                break;
        }
        return v;
    }

    public int predict(double[] neurons) {
        this.clf.network[0] = neurons;

        for (int i = 0; i < this.clf.network.length - 1; i++) {
            for (int j = 0; j < this.clf.network[i + 1].length; j++) {
                for (int l = 0; l < this.clf.network[i].length; l++) {
                    this.clf.network[i + 1][j] += this.clf.network[i][l] * this.clf.weights[i][l][j];
                }
                this.clf.network[i + 1][j] += this.clf.bias[i][j];
            }
            if ((i + 1) < (this.clf.network.length - 1)) {
                this.clf.network[i + 1] = this.compute(this.clf.hidden, this.clf.network[i + 1]);
            }
        }
        this.clf.network[this.clf.network.length - 1] = this.compute(this.clf.output, this.clf.network[this.clf.network.length - 1]);

        if (this.clf.network[this.clf.network.length - 1].length == 1) {
            if (this.clf.network[this.clf.network.length - 1][0] > .5) {
                return 1;
            }
            return 0;
        } else {
            int classIdx = 0;
            for (int i = 0; i < this.clf.network[this.clf.network.length - 1].length; i++) {
                classIdx = this.clf.network[this.clf.network.length - 1][i] > this.clf.network[this.clf.network.length - 1][classIdx] ? i : classIdx;
            }
            return classIdx;
        }
    }

    public static void main(String[] args) throws FileNotFoundException {
        if (args.length > 0 && args[0].endsWith(".json")) {

            // Features:
            double[] features = new double[args.length-1];
            for (int i = 1, l = args.length; i < l; i++) {
                features[i - 1] = Double.parseDouble(args[i]);
            }

            // Parameters:
            String modelData = args[0];

            // Estimators:
            MLPClassifier clf = new MLPClassifier(modelData);

            // Prediction:
            int prediction = clf.predict(features);
            System.out.println(prediction);

        }
    }
}

Run classification in Java


In [5]:
# Save classifier:
# with open('MLPClassifier.java', 'w') as f:
#     f.write(output)

# Check model data:
# $ cat data.json

# Download dependencies:
# $ wget -O gson.jar http://central.maven.org/maven2/com/google/code/gson/gson/2.8.5/gson-2.8.5.jar

# Compile model:
# $ javac -cp .:gson.jar MLPClassifier.java

# Run classification:
# $ java -cp .:gson.jar MLPClassifier data.json 1 2 3 4