sklearn-porter

Repository: https://github.com/nok/sklearn-porter

BernoulliNB

Documentation: sklearn.naive_bayes.BernoulliNB


In [1]:
import sys
sys.path.append('../../../../..')

Load data


In [2]:
from sklearn.datasets import load_iris

iris_data = load_iris()

X = iris_data.data
y = iris_data.target

print(X.shape, y.shape)


((150, 4), (150,))

Train classifier


In [3]:
from sklearn.naive_bayes import BernoulliNB

clf = BernoulliNB()
clf.fit(X, y)


Out[3]:
BernoulliNB(alpha=1.0, binarize=0.0, class_prior=None, fit_prior=True)

Transpile classifier


In [4]:
from sklearn_porter import Porter

porter = Porter(clf, language='js')
output = porter.export()

print(output)


var BernoulliNB = function(priors, negProbs, delProbs) {

    this.priors = priors;
    this.negProbs = negProbs;
    this.delProbs = delProbs;

    this.predict = function(features) {
        var nClasses = this.priors.length,
            nFeatures = this.delProbs.length;
    
        var jll = new Array(nClasses);
        for (var i = 0; i < nClasses; i++) {
            var sum = 0.;
            for (var j = 0; j < nFeatures; j++) {
                sum += features[i] * this.delProbs[j][i];
            }
            jll[i] = sum;
        }
        for (var i = 0; i < nClasses; i++) {
            var sum = 0.;
            for (var j = 0; j < nFeatures; j++) {
                sum += this.negProbs[i][j];
            }
            jll[i] += this.priors[i] + sum;
        }
        var classIdx = 0;
    
        for (var i = 0; i < nClasses; i++) {
            classIdx = jll[i] > jll[classIdx] ? i : classIdx;
        }
        return classIdx;
    };

};

if (typeof process !== 'undefined' && typeof process.argv !== 'undefined') {
    if (process.argv.length - 2 === 4) {

        // Features:
        var features = process.argv.slice(2);

        // Parameters:
        var priors = [-1.0986122886681096, -1.0986122886681096, -1.0986122886681096];
        var negProbs = [[-3.9512437185814138, -3.9512437185814138, -3.9512437185814138, -3.9512437185814138], [-3.9512437185814138, -3.9512437185814138, -3.9512437185814138, -3.9512437185814138], [-3.9512437185814138, -3.9512437185814138, -3.9512437185814138, -3.9512437185814138]];
        var delProbs = [[3.931825632724312, 3.931825632724312, 3.931825632724312], [3.931825632724312, 3.931825632724312, 3.931825632724312], [3.931825632724312, 3.931825632724312, 3.931825632724312], [3.931825632724312, 3.931825632724312, 3.931825632724312]];

        // Estimator:
        var clf = new BernoulliNB(priors, negProbs, delProbs);
        var prediction = clf.predict(features);
        console.log(prediction);

    }
}

Run classification in JavaScript


In [5]:
# Save classifier:
# with open('BernoulliNB.js', 'w') as f:
#     f.write(output)

# Run classification:
# if hash node 2/dev/null; then
#     node BernoulliNB.js 1 2 3 4
# fi