sklearn-porter

Repository: https://github.com/nok/sklearn-porter

AdaBoostClassifier

Documentation: sklearn.ensemble.AdaBoostClassifier


In [1]:
import sys
sys.path.append('../../../../..')

Load data


In [2]:
from sklearn.datasets import load_iris

iris_data = load_iris()

X = iris_data.data
y = iris_data.target

print(X.shape, y.shape)


((150, 4), (150,))

Train classifier


In [3]:
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier

base_estimator = DecisionTreeClassifier(max_depth=4, random_state=0)
clf = AdaBoostClassifier(base_estimator=base_estimator, n_estimators=100,
                         random_state=0)
clf.fit(X, y)


Out[3]:
AdaBoostClassifier(algorithm='SAMME.R',
          base_estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=4,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=0,
            splitter='best'),
          learning_rate=1.0, n_estimators=100, random_state=0)

Transpile classifier


In [4]:
from sklearn_porter import Porter

porter = Porter(clf, language='js')
output = porter.export(export_data=True)

print(output)


if (typeof XMLHttpRequest === 'undefined') {
    var XMLHttpRequest = require("xmlhttprequest").XMLHttpRequest;
}

var AdaBoostClassifier = function(jsonFile) {
    this.forest = undefined;

    var promise = new Promise(function(resolve, reject) {
        var httpRequest = new XMLHttpRequest();
        httpRequest.onreadystatechange = function() {
            if (httpRequest.readyState === 4) {
                if (httpRequest.status === 200) {
                    resolve(JSON.parse(httpRequest.responseText));
                } else {
                    reject(new Error(httpRequest.statusText));
                }
            }
        };
        httpRequest.open('GET', jsonFile, true);
        httpRequest.send();
    });

    var imax = function(nums) {
        var index = 0;
        for (var i=0, l=nums.length; i < l; i++) {
            index = nums[i] > nums[index] ? i : index;
        }
        return index;
    };

    var predict = function(tree, features, node) {
        if (tree['thresholds'][node] != -2) {
            if (features[tree['indices'][node]] <= tree['thresholds'][node]) {
                return predict(tree, features, tree['childrenLeft'][node]);
            } else {
                return predict(tree, features, tree['childrenRight'][node]);
            }
        }
        return tree['classes'][node].slice();
    };

    this.predict = function(features) {
        return new Promise(function(resolve, reject) {
            promise.then(function(forest) {
                if (typeof this.forest === 'undefined') {
                    this.forest = forest;
                    this.nEstimators = this.forest.length;
                    this.nClasses = this.forest[0]['classes'][0].length;
                }

                var preds = new Array(this.nEstimators).fill(new Array(this.nClasses).fill(0.));
                var i, j;
                for (i=0; i < this.nEstimators; i++) {
                    preds[i] = predict(this.forest[i], features, 0);
                }
                var classes = new Array(this.nClasses).fill(0.);
                for (i=0; i < this.nEstimators; i++) {
                    var normalizer = 0.;
                    for (j=0; j < this.nClasses; j++) {
                        normalizer += preds[i][j];
                    }
                    if (normalizer == 0.) {
                        normalizer = 1;
                    }
                    for (j = 0; j < this.nClasses; j++) {
                        preds[i][j] = preds[i][j] / normalizer;
                        if (preds[i][j] <= 2.2204460492503131e-16) {
                            preds[i][j] = 2.2204460492503131e-16;
                        }
                        preds[i][j] = Math.log(preds[i][j]);
                    }
                    var sum = 0.;
                    for (j = 0; j < this.nClasses; j++) {
                        sum += preds[i][j];
                    }
                    for (j = 0; j < this.nClasses; j++) {
                        preds[i][j] = (this.nClasses - 1) * (preds[i][j] - (1. / this.nClasses) * sum);
                    }

                    for (j = 0; j < this.nClasses; j++) {
                        classes[j] += preds[i][j];
                    }
                }
                resolve(imax(classes));
            }, function(error) {
                reject(error);
            });
        });
    };

};

if (typeof process !== 'undefined' && typeof process.argv !== 'undefined') {
    if (process.argv[2].trim().endsWith('.json')) {

        // Features:
        var features = process.argv.slice(3);

        // Parameters:
        var json = process.argv[2];

        // Estimator:
        var clf = new AdaBoostClassifier(json);

        // Prediction:
        clf.predict(features).then(function(prediction) {
            console.log(prediction);
        }, function(error) {
            console.log(error);
        });

    }
}

Run classification in JavaScript


In [5]:
# Save classifier:
# with open('AdaBoostClassifier.js', 'w') as f:
#     f.write(output)

# Check model data:
# $ cat data.json

# Run classification:
# if hash node 2/dev/null; then
#     python -m SimpleHTTPServer 8877 & serve_pid=$!
#     node AdaBoostClassifier.js http://127.0.0.1:8877/data.json 1 2 3 4
#     kill $serve_pid
# fi