Repository: https://github.com/nok/sklearn-porter
Documentation: sklearn.ensemble.RandomForestClassifier
In [1]:
import os
import sys

# Make the local sklearn-porter checkout (five directories above this
# notebook) importable. Resolve it to an absolute path so a later change of
# working directory cannot break the entry, insert it at the FRONT of
# sys.path so it takes precedence over any pip-installed sklearn_porter, and
# skip the insert when it is already present so re-running this cell is
# idempotent.
REPO_ROOT = os.path.abspath(os.path.join('..', '..', '..', '..', '..'))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
In [2]:
from sklearn.datasets import load_iris

# Load the iris data set: X is the feature matrix, y the integer class labels.
iris_data = load_iris()
X = iris_data.data
y = iris_data.target

# The exported Java `main` only predicts when given exactly four feature
# values, so pin the shapes the rest of the notebook relies on.
assert X.shape == (150, 4), X.shape
assert y.shape == (150,), y.shape

# Trailing expression instead of print(): it renders as ((150, 4), (150,))
# on both Python 2 and 3. print(X.shape, y.shape) differs between versions.
X.shape, y.shape
((150, 4), (150,))
In [3]:
from sklearn.ensemble import RandomForestClassifier

# Hyper-parameters of the forest that gets ported to Java below.
# A small n_estimators keeps the generated source short; random_state is
# fixed so refitting reproduces the same trees (and the same export).
forest_params = {
    'n_estimators': 15,
    'max_depth': None,
    'min_samples_split': 2,
    'random_state': 0,
}
clf = RandomForestClassifier(**forest_params)
clf.fit(X, y)
Out[3]:
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
max_depth=None, max_features='auto', max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, n_estimators=15, n_jobs=None,
oob_score=False, random_state=0, verbose=0, warm_start=False)
In [4]:
from sklearn_porter import Porter

# Transpile the fitted forest into a self-contained Java class.
porter = Porter(clf,
                language='java')  # target language of the generated source
output = porter.export(
    embed_data=True,  # inline every split threshold and leaf count
)
print(output)
// Random forest of 15 decision trees (predict_0 .. predict_14), generated by
// sklearn-porter from the scikit-learn model fitted earlier in this notebook.
// Every split threshold and leaf count is embedded directly in the source.
// Call RandomForestClassifier.predict(features) with the features in the same
// column order as the training matrix X.
class RandomForestClassifier {
// Decision tree 0: descends the `features[k] <= threshold` tests to a leaf,
// copies that leaf's per-class counts into `classes`, and returns the index
// of the largest count.
public static int predict_0(double[] features) {
int[] classes = new int[3];
if (features[3] <= 0.75) {
classes[0] = 47;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[2] <= 4.8500001430511475) {
if (features[3] <= 1.6500000357627869) {
classes[0] = 0;
classes[1] = 42;
classes[2] = 0;
} else {
if (features[1] <= 3.0) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 3;
} else {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
}
}
} else {
if (features[0] <= 6.599999904632568) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 27;
} else {
if (features[2] <= 5.200000047683716) {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 29;
}
}
}
}
// Argmax over the leaf counts; the strict `>` means ties go to the lowest
// class index.
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 1 (same evaluation scheme as predict_0).
public static int predict_1(double[] features) {
int[] classes = new int[3];
if (features[3] <= 0.800000011920929) {
classes[0] = 46;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[3] <= 1.75) {
if (features[2] <= 4.950000047683716) {
classes[0] = 0;
classes[1] = 58;
classes[2] = 0;
} else {
if (features[2] <= 5.450000047683716) {
if (features[1] <= 2.450000047683716) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 2;
} else {
classes[0] = 0;
classes[1] = 3;
classes[2] = 0;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 3;
}
}
} else {
if (features[2] <= 4.8500001430511475) {
if (features[1] <= 3.100000023841858) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 2;
} else {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 35;
}
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 2 (same evaluation scheme as predict_0).
public static int predict_2(double[] features) {
int[] classes = new int[3];
if (features[0] <= 5.549999952316284) {
if (features[3] <= 0.800000011920929) {
classes[0] = 49;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[3] <= 1.600000023841858) {
classes[0] = 0;
classes[1] = 12;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 1;
}
}
} else {
if (features[3] <= 1.550000011920929) {
if (features[3] <= 0.7500000149011612) {
classes[0] = 2;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[2] <= 5.0) {
classes[0] = 0;
classes[1] = 32;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 1;
}
}
} else {
if (features[2] <= 4.650000095367432) {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
} else {
if (features[3] <= 1.699999988079071) {
if (features[2] <= 5.450000047683716) {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 3;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 48;
}
}
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 3 (same evaluation scheme as predict_0).
public static int predict_3(double[] features) {
int[] classes = new int[3];
if (features[0] <= 5.450000047683716) {
if (features[1] <= 2.8000000715255737) {
if (features[1] <= 2.450000047683716) {
classes[0] = 0;
classes[1] = 5;
classes[2] = 0;
} else {
if (features[0] <= 5.0) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 3;
} else {
classes[0] = 0;
classes[1] = 3;
classes[2] = 0;
}
}
} else {
classes[0] = 41;
classes[1] = 0;
classes[2] = 0;
}
} else {
if (features[0] <= 6.25) {
if (features[3] <= 1.699999988079071) {
if (features[3] <= 0.6000000014901161) {
classes[0] = 3;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[1] <= 2.25) {
if (features[3] <= 1.25) {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
} else {
if (features[2] <= 4.75) {
classes[0] = 0;
classes[1] = 3;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 1;
}
}
} else {
classes[0] = 0;
classes[1] = 37;
classes[2] = 0;
}
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 8;
}
} else {
if (features[2] <= 4.950000047683716) {
classes[0] = 0;
classes[1] = 10;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 35;
}
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 4 (same evaluation scheme as predict_0).
public static int predict_4(double[] features) {
int[] classes = new int[3];
if (features[3] <= 0.7000000029802322) {
classes[0] = 50;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[3] <= 1.75) {
if (features[2] <= 5.049999952316284) {
if (features[2] <= 4.950000047683716) {
classes[0] = 0;
classes[1] = 56;
classes[2] = 0;
} else {
if (features[3] <= 1.600000023841858) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 1;
} else {
classes[0] = 0;
classes[1] = 3;
classes[2] = 0;
}
}
} else {
if (features[0] <= 6.049999952316284) {
classes[0] = 0;
classes[1] = 2;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 5;
}
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 33;
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 5 (same evaluation scheme as predict_0).
public static int predict_5(double[] features) {
int[] classes = new int[3];
if (features[3] <= 0.800000011920929) {
classes[0] = 49;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[2] <= 4.950000047683716) {
if (features[0] <= 4.950000047683716) {
if (features[3] <= 1.350000023841858) {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 1;
}
} else {
if (features[2] <= 4.75) {
classes[0] = 0;
classes[1] = 49;
classes[2] = 0;
} else {
if (features[1] <= 2.600000023841858) {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
} else {
if (features[0] <= 6.049999952316284) {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
} else {
if (features[3] <= 1.5999999642372131) {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 3;
}
}
}
}
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 44;
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 6 (same evaluation scheme as predict_0).
public static int predict_6(double[] features) {
int[] classes = new int[3];
if (features[3] <= 0.7000000029802322) {
classes[0] = 46;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[2] <= 4.75) {
if (features[0] <= 4.950000047683716) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 2;
} else {
classes[0] = 0;
classes[1] = 39;
classes[2] = 0;
}
} else {
if (features[2] <= 5.1499998569488525) {
if (features[0] <= 6.599999904632568) {
if (features[3] <= 1.699999988079071) {
if (features[3] <= 1.550000011920929) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 2;
} else {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 19;
}
} else {
classes[0] = 0;
classes[1] = 3;
classes[2] = 0;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 38;
}
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 7 (same evaluation scheme as predict_0).
public static int predict_7(double[] features) {
int[] classes = new int[3];
if (features[2] <= 2.599999964237213) {
classes[0] = 58;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[2] <= 4.75) {
classes[0] = 0;
classes[1] = 37;
classes[2] = 0;
} else {
if (features[2] <= 5.1499998569488525) {
if (features[3] <= 1.75) {
if (features[0] <= 6.5) {
if (features[2] <= 4.950000047683716) {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
} else {
if (features[0] <= 6.150000095367432) {
if (features[3] <= 1.550000011920929) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 2;
} else {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 2;
}
}
} else {
classes[0] = 0;
classes[1] = 2;
classes[2] = 0;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 13;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 34;
}
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 8 (same evaluation scheme as predict_0).
public static int predict_8(double[] features) {
int[] classes = new int[3];
if (features[3] <= 0.7000000029802322) {
classes[0] = 42;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[0] <= 6.25) {
if (features[2] <= 4.799999952316284) {
if (features[0] <= 4.950000047683716) {
if (features[1] <= 2.450000047683716) {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 3;
}
} else {
classes[0] = 0;
classes[1] = 36;
classes[2] = 0;
}
} else {
if (features[3] <= 1.550000011920929) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 4;
} else {
if (features[3] <= 1.699999988079071) {
classes[0] = 0;
classes[1] = 2;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 4;
}
}
}
} else {
if (features[3] <= 1.75) {
if (features[2] <= 5.049999952316284) {
classes[0] = 0;
classes[1] = 15;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 4;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 39;
}
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 9 (same evaluation scheme as predict_0).
public static int predict_9(double[] features) {
int[] classes = new int[3];
if (features[2] <= 2.599999964237213) {
classes[0] = 55;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[2] <= 4.950000047683716) {
if (features[0] <= 5.950000047683716) {
classes[0] = 0;
classes[1] = 23;
classes[2] = 0;
} else {
if (features[3] <= 1.649999976158142) {
classes[0] = 0;
classes[1] = 16;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 4;
}
}
} else {
if (features[0] <= 6.599999904632568) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 33;
} else {
if (features[0] <= 6.75) {
if (features[3] <= 2.0) {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 4;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 14;
}
}
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 10 (same evaluation scheme as predict_0).
public static int predict_10(double[] features) {
int[] classes = new int[3];
if (features[3] <= 0.800000011920929) {
classes[0] = 52;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[2] <= 4.75) {
classes[0] = 0;
classes[1] = 37;
classes[2] = 0;
} else {
if (features[3] <= 1.75) {
if (features[2] <= 4.950000047683716) {
classes[0] = 0;
classes[1] = 4;
classes[2] = 0;
} else {
if (features[1] <= 2.649999976158142) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 2;
} else {
if (features[3] <= 1.550000011920929) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 2;
} else {
if (features[2] <= 5.450000047683716) {
classes[0] = 0;
classes[1] = 2;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 1;
}
}
}
}
} else {
if (features[2] <= 4.8500001430511475) {
if (features[1] <= 3.100000023841858) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 6;
} else {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 43;
}
}
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 11 (same evaluation scheme as predict_0).
public static int predict_11(double[] features) {
int[] classes = new int[3];
if (features[2] <= 2.599999964237213) {
classes[0] = 47;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[2] <= 4.75) {
classes[0] = 0;
classes[1] = 40;
classes[2] = 0;
} else {
if (features[2] <= 4.950000047683716) {
if (features[1] <= 3.049999952316284) {
if (features[3] <= 1.5999999642372131) {
classes[0] = 0;
classes[1] = 2;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 7;
}
} else {
classes[0] = 0;
classes[1] = 2;
classes[2] = 0;
}
} else {
if (features[0] <= 6.049999952316284) {
if (features[2] <= 5.049999952316284) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 4;
} else {
if (features[0] <= 5.950000047683716) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 7;
} else {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
}
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 40;
}
}
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 12 (same evaluation scheme as predict_0).
public static int predict_12(double[] features) {
int[] classes = new int[3];
if (features[3] <= 0.800000011920929) {
classes[0] = 54;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[1] <= 2.450000047683716) {
if (features[2] <= 4.75) {
classes[0] = 0;
classes[1] = 12;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 1;
}
} else {
if (features[3] <= 1.600000023841858) {
if (features[2] <= 5.0) {
classes[0] = 0;
classes[1] = 23;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 2;
}
} else {
if (features[3] <= 1.75) {
if (features[0] <= 5.799999952316284) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 3;
} else {
classes[0] = 0;
classes[1] = 2;
classes[2] = 0;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 53;
}
}
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 13 (same evaluation scheme as predict_0).
public static int predict_13(double[] features) {
int[] classes = new int[3];
if (features[0] <= 5.450000047683716) {
if (features[3] <= 0.800000011920929) {
classes[0] = 36;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[2] <= 4.200000047683716) {
classes[0] = 0;
classes[1] = 6;
classes[2] = 0;
} else {
if (features[1] <= 2.75) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 1;
} else {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
}
}
}
} else {
if (features[2] <= 4.900000095367432) {
if (features[1] <= 3.600000023841858) {
classes[0] = 0;
classes[1] = 43;
classes[2] = 0;
} else {
classes[0] = 7;
classes[1] = 0;
classes[2] = 0;
}
} else {
if (features[3] <= 1.699999988079071) {
if (features[3] <= 1.550000011920929) {
classes[0] = 0;
classes[1] = 0;
classes[2] = 2;
} else {
classes[0] = 0;
classes[1] = 4;
classes[2] = 0;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 50;
}
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Decision tree 14 (same evaluation scheme as predict_0).
public static int predict_14(double[] features) {
int[] classes = new int[3];
if (features[2] <= 2.599999964237213) {
classes[0] = 52;
classes[1] = 0;
classes[2] = 0;
} else {
if (features[3] <= 1.699999988079071) {
if (features[0] <= 7.0) {
if (features[2] <= 5.0) {
classes[0] = 0;
classes[1] = 48;
classes[2] = 0;
} else {
if (features[0] <= 6.049999952316284) {
classes[0] = 0;
classes[1] = 1;
classes[2] = 0;
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 2;
}
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 1;
}
} else {
classes[0] = 0;
classes[1] = 0;
classes[2] = 46;
}
}
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < 3; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Ensemble prediction: each of the 15 trees casts one vote for its predicted
// class, and the class with the most votes wins (strict `>`, so ties go to
// the lowest class index).
// NOTE(review): this is a hard majority vote, whereas scikit-learn's
// RandomForestClassifier.predict averages per-tree class probabilities
// (per its documentation). The two agree for this export because every leaf
// above has exactly one non-zero class count. Confirm this still holds if
// the model is re-exported with different hyper-parameters.
public static int predict(double[] features) {
int n_classes = 3;
int[] classes = new int[n_classes];
classes[RandomForestClassifier.predict_0(features)]++;
classes[RandomForestClassifier.predict_1(features)]++;
classes[RandomForestClassifier.predict_2(features)]++;
classes[RandomForestClassifier.predict_3(features)]++;
classes[RandomForestClassifier.predict_4(features)]++;
classes[RandomForestClassifier.predict_5(features)]++;
classes[RandomForestClassifier.predict_6(features)]++;
classes[RandomForestClassifier.predict_7(features)]++;
classes[RandomForestClassifier.predict_8(features)]++;
classes[RandomForestClassifier.predict_9(features)]++;
classes[RandomForestClassifier.predict_10(features)]++;
classes[RandomForestClassifier.predict_11(features)]++;
classes[RandomForestClassifier.predict_12(features)]++;
classes[RandomForestClassifier.predict_13(features)]++;
classes[RandomForestClassifier.predict_14(features)]++;
int class_idx = 0;
int class_val = classes[0];
for (int i = 1; i < n_classes; i++) {
if (classes[i] > class_val) {
class_idx = i;
class_val = classes[i];
}
}
return class_idx;
}
// Command-line entry point: parses four numeric arguments as the feature
// vector and prints the predicted class index.
// NOTE(review): with any other argument count, main returns without printing
// anything. Consider emitting a usage message instead.
public static void main(String[] args) {
if (args.length == 4) {
// Features:
double[] features = new double[args.length];
for (int i = 0, l = args.length; i < l; i++) {
features[i] = Double.parseDouble(args[i]);
}
// Prediction:
int prediction = RandomForestClassifier.predict(features);
System.out.println(prediction);
}
}
}
In [5]:
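# Save the exported Java source to a file:
# with open('RandomForestClassifier.java', 'w') as f:
#     f.write(output)
# Compile the model:
# $ javac -cp . RandomForestClassifier.java
# Run a classification (pass exactly four feature values; with any other
# number of arguments the program prints nothing):
# $ java RandomForestClassifier 1 2 3 4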
Content source: nok/sklearn-porter
Similar notebooks: