sklearn-porter

Repository: https://github.com/nok/sklearn-porter

DecisionTreeClassifier

Documentation: sklearn.tree.DecisionTreeClassifier


In [1]:
import sys
# Extend the module search path with the directory five levels above the
# notebook's working directory.
# NOTE(review): presumably this is the repository root, added so that the
# local `sklearn_porter` package imported further down resolves — confirm.
sys.path.append('../../../../..')

Load data


In [2]:
from sklearn.datasets import load_iris

# Fetch the iris data set that ships with scikit-learn.
iris_data = load_iris()

# Split the bundle into the feature matrix (X) and the class labels (y).
X, y = iris_data.data, iris_data.target

# Sanity check: show the dimensions of both arrays.
print(X.shape, y.shape)


((150, 4), (150,))

Train classifier


In [3]:
# Import the public `sklearn.tree` package rather than the `sklearn.tree.tree`
# submodule: that submodule was deprecated in scikit-learn 0.22 and removed in
# 0.24, so `from sklearn.tree import tree` fails on current releases. The
# `tree.DecisionTreeClassifier` spelling below works on old and new versions.
from sklearn import tree

# Train a decision tree with default hyper-parameters on the full data set.
# NOTE(review): no `random_state` is passed, so ties between equally good
# splits may be resolved differently from run to run; pass a fixed
# `random_state` if reproducible output is required.
clf = tree.DecisionTreeClassifier()

# `fit` is the last expression so the fitted estimator is shown as cell output.
clf.fit(X, y)


Out[3]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')

Transpile classifier


In [4]:
from sklearn_porter import Porter

# Wrap the fitted classifier and select Java as the target language.
porter = Porter(clf, language='java')
# Export the classifier as source code. With `embed_data=True` the printed
# result below contains the split thresholds and leaf values inline.
output = porter.export(embed_data=True)

# Print the generated Java source so it is recorded as cell output.
print(output)

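# Example output (split features and thresholds may differ between runs,
# presumably because the classifier is trained without a fixed random_state):
# class DecisionTreeClassifier {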
#
#     private static int findMax(int[] nums) {
#         int index = 0;
#         for (int i = 0; i < nums.length; i++) {
#             index = nums[i] > nums[index] ? i : index;
#         }
#         return index;
#     }
#
#     public static int predict(double[] features) {
#         int[] classes = new int[3];
#
#         if (features[3] <= 0.800000011920929) {
#             classes[0] = 50;
#             classes[1] = 0;
#             classes[2] = 0;
#         } else {
#             if (features[3] <= 1.75) {
#                 if (features[2] <= 4.950000047683716) {
#                     if (features[3] <= 1.6500000357627869) {
#                         classes[0] = 0;
#                         classes[1] = 47;
#                         classes[2] = 0;
#                     } else {
#                         classes[0] = 0;
#                         classes[1] = 0;
#                         classes[2] = 1;
#                     }
#                 } else {
#                     if (features[3] <= 1.550000011920929) {
#                         classes[0] = 0;
#                         classes[1] = 0;
#                         classes[2] = 3;
#                     } else {
#                         if (features[2] <= 5.450000047683716) {
#                             classes[0] = 0;
#                             classes[1] = 2;
#                             classes[2] = 0;
#                         } else {
#                             classes[0] = 0;
#                             classes[1] = 0;
#                             classes[2] = 1;
#                         }
#                     }
#                 }
#             } else {
#                 if (features[2] <= 4.8500001430511475) {
#                     if (features[0] <= 5.950000047683716) {
#                         classes[0] = 0;
#                         classes[1] = 1;
#                         classes[2] = 0;
#                     } else {
#                         classes[0] = 0;
#                         classes[1] = 0;
#                         classes[2] = 2;
#                     }
#                 } else {
#                     classes[0] = 0;
#                     classes[1] = 0;
#                     classes[2] = 43;
#                 }
#             }
#         }
#
#         return findMax(classes);
#     }
#
#     public static void main(String[] args) {
#         if (args.length == 4) {
#
#             // Features:
#             double[] features = new double[args.length];
#             for (int i = 0, l = args.length; i < l; i++) {
#                 features[i] = Double.parseDouble(args[i]);
#             }
#
#             // Prediction:
#             int prediction = DecisionTreeClassifier.predict(features);
#             System.out.println(prediction);
#
#         }
#     }
# }


/**
 * Decision tree exported as plain Java: the tree is encoded as nested
 * if/else statements with the split thresholds written inline.
 */
class DecisionTreeClassifier {

    /**
     * Returns the index of the largest value in {@code nums}.
     * The comparison is strict, so the first index wins on ties;
     * an empty array yields 0.
     *
     * @param nums values to scan
     * @return index of the first maximum
     */
    private static int findMax(int[] nums) {
        int index = 0;
        for (int i = 0; i < nums.length; i++) {
            index = nums[i] > nums[index] ? i : index;
        }
        return index;
    }

    /**
     * Walks the hard-coded tree for one sample and returns the predicted
     * class index (0, 1 or 2).
     *
     * Only {@code features[0]}, {@code features[2]} and {@code features[3]}
     * are read; {@code features[3]} is always read, so the array needs at
     * least four elements.
     *
     * @param features feature values of the sample
     * @return index of the class with the largest value at the reached leaf
     */
    public static int predict(double[] features) {
        // Per-class values of the reached leaf.
        // NOTE(review): presumably training-sample counts per class — all
        // leaf values together sum to 150; confirm against the exporter.
        int[] classes = new int[3];
            
        if (features[3] <= 0.800000011920929) {
            classes[0] = 50; 
            classes[1] = 0; 
            classes[2] = 0; 
        } else {
            if (features[3] <= 1.75) {
                if (features[2] <= 4.950000047683716) {
                    if (features[3] <= 1.6500000357627869) {
                        classes[0] = 0; 
                        classes[1] = 47; 
                        classes[2] = 0; 
                    } else {
                        classes[0] = 0; 
                        classes[1] = 0; 
                        classes[2] = 1; 
                    }
                } else {
                    if (features[3] <= 1.550000011920929) {
                        classes[0] = 0; 
                        classes[1] = 0; 
                        classes[2] = 3; 
                    } else {
                        if (features[0] <= 6.949999809265137) {
                            classes[0] = 0; 
                            classes[1] = 2; 
                            classes[2] = 0; 
                        } else {
                            classes[0] = 0; 
                            classes[1] = 0; 
                            classes[2] = 1; 
                        }
                    }
                }
            } else {
                if (features[2] <= 4.8500001430511475) {
                    if (features[0] <= 5.950000047683716) {
                        classes[0] = 0; 
                        classes[1] = 1; 
                        classes[2] = 0; 
                    } else {
                        classes[0] = 0; 
                        classes[1] = 0; 
                        classes[2] = 2; 
                    }
                } else {
                    classes[0] = 0; 
                    classes[1] = 0; 
                    classes[2] = 43; 
                }
            }
        }
    
        // The predicted class is the one with the largest leaf value.
        return findMax(classes);
    }

    /**
     * Command-line entry point: expects exactly four numeric arguments,
     * parses them as feature values and prints the predicted class index.
     * With any other number of arguments nothing is printed.
     *
     * @param args the four feature values as strings
     */
    public static void main(String[] args) {
        if (args.length == 4) {

            // Features:
            double[] features = new double[args.length];
            for (int i = 0, l = args.length; i < l; i++) {
                features[i] = Double.parseDouble(args[i]);
            }

            // Prediction:
            int prediction = DecisionTreeClassifier.predict(features);
            System.out.println(prediction);

        }
    }
}

Run classification in Java


In [5]:
# Save classifier (uncomment to write the generated source to disk):
# with open('DecisionTreeClassifier.java', 'w') as f:
#     f.write(output)

# Compile model:
# $ javac -cp . DecisionTreeClassifier.java

# Run classification:
# $ java DecisionTreeClassifier 1 2 3 4