sklearn-porter

Repository: https://github.com/nok/sklearn-porter

AdaBoostClassifier

Documentation: sklearn.ensemble.AdaBoostClassifier


In [1]:
import sys
# Make the local sklearn_porter package (from the repository checkout) importable:
sys.path.append('../../../../..')

Load data


In [2]:
from sklearn.datasets import load_iris

iris_data = load_iris()

X = iris_data.data
y = iris_data.target

print(X.shape, y.shape)


((150, 4), (150,))
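The dataset provides 150 samples with 4 features each and 3 target classes. A minimal sketch for inspecting the feature and class names (not part of the original notebook; it only assumes the standard fields returned by load_iris):

print(iris_data.feature_names)
print(iris_data.target_names)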

Train classifier


In [3]:
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier

base_estimator = DecisionTreeClassifier(max_depth=4, random_state=0)
clf = AdaBoostClassifier(base_estimator=base_estimator, n_estimators=100,
                         random_state=0)
clf.fit(X, y)


Out[3]:
AdaBoostClassifier(algorithm='SAMME.R',
          base_estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=4,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=0,
            splitter='best'),
          learning_rate=1.0, n_estimators=100, random_state=0)
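Before transpiling, it can help to record what the fitted ensemble predicts in Python, so the generated Java class can be checked against it later. A minimal sketch (not part of the original notebook; it uses only the fitted clf from above and a hypothetical sample matching the Java example below):

# Training accuracy and one reference prediction for later comparison with the Java port:
print(clf.score(X, y))
print(clf.predict([[1., 2., 3., 4.]]))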

Transpile classifier


In [4]:
from sklearn_porter import Porter

porter = Porter(clf, language='java')
output = porter.export(embed_data=True)

print(output)


class AdaBoostClassifier {

    private static int findMax(double[] nums) {
        int index = 0;
        for (int i = 0; i < nums.length; i++) {
            index = nums[i] > nums[index] ? i : index;
        }
        return index;
    }
    
    private static double[] predict_0(double[] features) {
        double[] classes = new double[3];
        if (features[3] <= 0.800000011920929) {
            classes[0] = 0.333333333333333; 
            classes[1] = 0.0; 
            classes[2] = 0.0; 
        } else {
            if (features[3] <= 1.75) {
                if (features[2] <= 4.950000047683716) {
                    if (features[3] <= 1.6500000357627869) {
                        classes[0] = 0.0; 
                        classes[1] = 0.313333333333333; 
                        classes[2] = 0.0; 
                    } else {
                        classes[0] = 0.0; 
                        classes[1] = 0.0; 
                        classes[2] = 0.006666666666666667; 
                    }
                } else {
                    if (features[3] <= 1.550000011920929) {
                        classes[0] = 0.0; 
                        classes[1] = 0.0; 
                        classes[2] = 0.02; 
                    } else {
                        classes[0] = 0.0; 
                        classes[1] = 0.013333333333333334; 
                        classes[2] = 0.006666666666666667; 
                    }
                }
            } else {
                if (features[2] <= 4.8500001430511475) {
                    if (features[0] <= 5.950000047683716) {
                        classes[0] = 0.0; 
                        classes[1] = 0.006666666666666667; 
                        classes[2] = 0.0; 
                    } else {
                        classes[0] = 0.0; 
                        classes[1] = 0.0; 
                        classes[2] = 0.013333333333333334; 
                    }
                } else {
                    classes[0] = 0.0; 
                    classes[1] = 0.0; 
                    classes[2] = 0.2866666666666664; 
                }
            }
        }
        return classes;
    }
    
    private static double[] predict_1(double[] features) {
        double[] classes = new double[3];
        if (features[2] <= 5.1499998569488525) {
            if (features[2] <= 2.449999988079071) {
                classes[0] = 8.32907244640284e-05; 
                classes[1] = 0.0; 
                classes[2] = 0.0; 
            } else {
                if (features[3] <= 1.75) {
                    if (features[0] <= 4.950000047683716) {
                        classes[0] = 0.0; 
                        classes[1] = 1.6658144892805682e-06; 
                        classes[2] = 1.6658144892805682e-06; 
                    } else {
                        classes[0] = 0.0; 
                        classes[1] = 0.49995419010154496; 
                        classes[2] = 3.3316289785611363e-06; 
                    }
                } else {
                    if (features[1] <= 3.149999976158142) {
                        classes[0] = 0.0; 
                        classes[1] = 0.0; 
                        classes[2] = 1.9989773871366814e-05; 
                    } else {
                        classes[0] = 0.0; 
                        classes[1] = 1.6658144892805682e-06; 
                        classes[2] = 1.6658144892805682e-06; 
                    }
                }
            }
        } else {
            classes[0] = 0.0; 
            classes[1] = 0.0; 
            classes[2] = 0.4999325345131842; 
        }
        return classes;
    }
    
    private static double[] predict_2(double[] features) {
        double[] classes = new double[3];
        if (features[3] <= 1.550000011920929) {
            if (features[2] <= 4.950000047683716) {
                if (features[3] <= 0.800000011920929) {
                    classes[0] = 2.6788177186451792e-08; 
                    classes[1] = 0.0; 
                    classes[2] = 0.0; 
                } else {
                    classes[0] = 0.0; 
                    classes[1] = 0.00018473109499329488; 
                    classes[2] = 0.0; 
                }
            } else {
                classes[0] = 0.0; 
                classes[1] = 0.0; 
                classes[2] = 0.49969664310232625; 
            }
        } else {
            if (features[2] <= 5.1499998569488525) {
                if (features[3] <= 1.8499999642372131) {
                    if (features[0] <= 5.400000095367432) {
                        classes[0] = 0.0; 
                        classes[1] = 0.0; 
                        classes[2] = 0.00011147301524887026; 
                    } else {
                        classes[0] = 0.0; 
                        classes[1] = 0.49973485750206614; 
                        classes[2] = 2.6788177186451756e-09; 
                    }
                } else {
                    classes[0] = 0.0; 
                    classes[1] = 0.0; 
                    classes[2] = 0.00011147676559367639; 
                }
            } else {
                classes[0] = 0.0; 
                classes[1] = 0.0; 
                classes[2] = 0.00016078905277695348; 
            }
        }
        return classes;
    }
    
    private static double[] predict_3(double[] features) {
        double[] classes = new double[3];
        if (features[3] <= 1.75) {
            if (features[3] <= 1.550000011920929) {
                if (features[2] <= 4.950000047683716) {
                    if (features[3] <= 0.800000011920929) {
                        classes[0] = 9.257653973762734e-11; 
                        classes[1] = 0.0; 
                        classes[2] = 0.0; 
                    } else {
                        classes[0] = 0.0; 
                        classes[1] = 6.384072136521275e-07; 
                        classes[2] = 0.0; 
                    }
                } else {
                    classes[0] = 0.0; 
                    classes[1] = 0.0; 
                    classes[2] = 0.0017268881646907192; 
                }
            } else {
                if (features[0] <= 6.949999809265137) {
                    if (features[1] <= 2.600000023841858) {
                        classes[0] = 0.0; 
                        classes[1] = 0.0; 
                        classes[2] = 3.852365897848193e-07; 
                    } else {
                        classes[0] = 0.0; 
                        classes[1] = 0.4990242342550203; 
                        classes[2] = 0.0; 
                    }
                } else {
                    classes[0] = 0.0; 
                    classes[1] = 0.0; 
                    classes[2] = 5.556073060838475e-07; 
                }
            }
        } else {
            if (features[1] <= 3.149999976158142) {
                classes[0] = 0.0; 
                classes[1] = 0.0; 
                classes[2] = 0.49913557364140265; 
            } else {
                if (features[2] <= 4.950000047683716) {
                    classes[0] = 0.0; 
                    classes[1] = 0.00011133933639195673; 
                    classes[2] = 0.0; 
                } else {
                    classes[0] = 0.0; 
                    classes[1] = 0.0; 
                    classes[2] = 3.852588081543566e-07; 
                }
            }
        }
        return classes;
    }
    
    public static int predict(double[] features) {
        int n_estimators = 4;
        int n_classes = 3;
    
        double[][] preds = new double[n_estimators][];
        preds[0] = AdaBoostClassifier.predict_0(features);
        preds[1] = AdaBoostClassifier.predict_1(features);
        preds[2] = AdaBoostClassifier.predict_2(features);
        preds[3] = AdaBoostClassifier.predict_3(features);
    
        int i, j;
        double normalizer, sum;
        for (i = 0; i < n_estimators; i++) {
            normalizer = 0.;
            for (j = 0; j < n_classes; j++) {
                normalizer += preds[i][j];
            }
            if (normalizer == 0.) {
                normalizer = 1.;
            }
            for (j = 0; j < n_classes; j++) {
                preds[i][j] = preds[i][j] / normalizer;
                if (preds[i][j] <= 2.2204460492503131e-16) {
                    preds[i][j] = 2.2204460492503131e-16;
                }
                preds[i][j] = Math.log(preds[i][j]);
            }
            sum = 0.;
            for (j = 0; j < n_classes; j++) {
                sum += preds[i][j];
            }
            for (j = 0; j < n_classes; j++) {
                preds[i][j] = (n_classes - 1) * (preds[i][j] - (1. / n_classes) * sum);
            }
        }
        double[] classes = new double[n_classes];
        for (i = 0; i < n_estimators; i++) {
            for (j = 0; j < n_classes; j++) {
                classes[j] += preds[i][j];
            }
        }
    
        return AdaBoostClassifier.findMax(classes);
    }

    public static void main(String[] args) {
        if (args.length == 4) {

            // Features:
            double[] features = new double[args.length];
            for (int i = 0, l = args.length; i < l; i++) {
                features[i] = Double.parseDouble(args[i]);
            }

            // Prediction:
            int prediction = AdaBoostClassifier.predict(features);
            System.out.println(prediction);

        }
    }

}

Run classification in Java


In [5]:
# Save the transpiled classifier:
# with open('AdaBoostClassifier.java', 'w') as f:
#     f.write(output)

# Compile the Java class:
# $ javac -cp . AdaBoostClassifier.java

# Run the classification:
# $ java AdaBoostClassifier 1 2 3 4
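The commented steps above can also be driven directly from Python to confirm that the Java port agrees with the scikit-learn model. A minimal sketch (not part of the original notebook; it assumes javac and java are available on the PATH and uses only standard-library calls plus the clf and output objects defined above):

import subprocess

# Save the transpiled classifier:
with open('AdaBoostClassifier.java', 'w') as f:
    f.write(output)

# Compile the generated class:
subprocess.check_call(['javac', '-cp', '.', 'AdaBoostClassifier.java'])

# Classify one sample in Java and compare with the Python prediction:
java_pred = subprocess.check_output(
    ['java', '-cp', '.', 'AdaBoostClassifier', '1', '2', '3', '4']).strip()
print(int(java_pred), clf.predict([[1., 2., 3., 4.]])[0])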