sklearn-porter

Repository: https://github.com/nok/sklearn-porter

GaussianNB

Documentation: sklearn.naive_bayes.GaussianNB


In [1]:
import sys
sys.path.append('../../../../..')

Load data


In [2]:
from sklearn.datasets import load_iris

iris_data = load_iris()

X = iris_data.data
y = iris_data.target

print(X.shape, y.shape)


((150, 4), (150,))

Train classifier


In [3]:
from sklearn.naive_bayes import GaussianNB

clf = GaussianNB()
clf.fit(X, y)


Out[3]:
GaussianNB(priors=None, var_smoothing=1e-09)

Transpile classifier


In [4]:
from sklearn_porter import Porter

porter = Porter(clf, language='java')
output = porter.export(export_data=True)

print(output)


import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
import com.google.gson.Gson;


class GaussianNB {

    private class Classifier {
        private double[] priors;
        private double[][] sigmas;
        private double[][] thetas;
    }

    private Classifier clf;

    public GaussianNB(String file) throws FileNotFoundException {
        String jsonStr = new Scanner(new File(file)).useDelimiter("\\Z").next();
        this.clf = new Gson().fromJson(jsonStr, Classifier.class);
    }

    public int predict(double[] features) {
        double[] likelihoods = new double[this.clf.sigmas.length];

        for (int i = 0, il = this.clf.sigmas.length; i < il; i++) {
            double sum = 0.;
            for (int j = 0, jl = this.clf.sigmas[0].length; j < jl; j++) {
                sum += Math.log(2. * Math.PI * this.clf.sigmas[i][j]);
            }
            double nij = -0.5 * sum;
            sum = 0.;
            for (int j = 0, jl = this.clf.sigmas[0].length; j < jl; j++) {
                sum += Math.pow(features[j] - this.clf.thetas[i][j], 2.) / this.clf.sigmas[i][j];
            }
            nij -= 0.5 * sum;
            likelihoods[i] = Math.log(this.clf.priors[i]) + nij;
        }

        int classIdx = 0;
        for (int i = 0, l = likelihoods.length; i < l; i++) {
            classIdx = likelihoods[i] > likelihoods[classIdx] ? i : classIdx;
        }
        return classIdx;
    }

    public static void main(String[] args) throws FileNotFoundException {
        if (args.length > 0 && args[0].endsWith(".json")) {

            // Features:
            double[] features = new double[args.length-1];
            for (int i = 1, l = args.length; i < l; i++) {
                features[i - 1] = Double.parseDouble(args[i]);
            }

            // Parameters:
            String modelData = args[0];

            // Estimators:
            GaussianNB clf = new GaussianNB(modelData);

            // Prediction:
            int prediction = clf.predict(features);
            System.out.println(prediction);

        }
    }
}

Run classification in Java


In [5]:
# Save classifier:
# with open('GaussianNB.java', 'w') as f:
#     f.write(output)

# Check model data:
# $ cat data.json

# Download dependencies:
# $ wget -O gson.jar http://central.maven.org/maven2/com/google/code/gson/gson/2.8.5/gson-2.8.5.jar

# Compile model:
# $ javac -cp .:gson.jar GaussianNB.java

# Run classification:
# $ java -cp .:gson.jar GaussianNB data.json 1 2 3 4