sklearn-porter

Repository: https://github.com/nok/sklearn-porter

MLPClassifier

Documentation: sklearn.neural_network.MLPClassifier


In [1]:
import sys
sys.path.append('../../../../..')

Load data


In [2]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

iris_data = load_iris()
X = iris_data.data
y = iris_data.target

X = shuffle(X, random_state=0)
y = shuffle(y, random_state=0)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.4, random_state=5)

print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)


((90, 4), (90,))
((60, 4), (60,))

Train classifier


In [3]:
from sklearn.neural_network import MLPClassifier

clf = MLPClassifier(activation='relu', hidden_layer_sizes=50,
                    max_iter=500, alpha=1e-4, solver='sgd',
                    tol=1e-4, random_state=1, learning_rate_init=.1)
clf.fit(X_train, y_train)


Out[3]:
MLPClassifier(activation='relu', alpha=0.0001, batch_size='auto', beta_1=0.9,
       beta_2=0.999, early_stopping=False, epsilon=1e-08,
       hidden_layer_sizes=50, learning_rate='constant',
       learning_rate_init=0.1, max_iter=500, momentum=0.9,
       n_iter_no_change=10, nesterovs_momentum=True, power_t=0.5,
       random_state=1, shuffle=True, solver='sgd', tol=0.0001,
       validation_fraction=0.1, verbose=False, warm_start=False)

Transpile classifier


In [4]:
from sklearn_porter import Porter

porter = Porter(clf, language='java')
output = porter.export()

print(output)


class MLPClassifier {

    private enum Activation { IDENTITY, LOGISTIC, RELU, TANH, SOFTMAX }

    private Activation hidden;
    private Activation output;
    private double[][] network;
    private double[][][] weights;
    private double[][] bias;

    public MLPClassifier(String hidden, String output, int[] layers, double[][][] weights, double[][] bias) {
        this.hidden = Activation.valueOf(hidden.toUpperCase());
        this.output = Activation.valueOf(output.toUpperCase());
        this.network = new double[layers.length + 1][];
        for (int i = 0, l = layers.length; i < l; i++) {
            this.network[i + 1] = new double[layers[i]];
        }
        this.weights = weights;
        this.bias = bias;
    }

    public MLPClassifier(String hidden, String output, int neurons, double[][][] weights, double[][] bias) {
        this(hidden, output, new int[] { neurons }, weights, bias);
    }

    private double[] compute(Activation activation, double[] v) {
        switch (activation) {
            case LOGISTIC:
                for (int i = 0, l = v.length; i < l; i++) {
                    v[i] = 1. / (1. + Math.exp(-v[i]));
                }
                break;
            case RELU:
                for (int i = 0, l = v.length; i < l; i++) {
                    v[i] = Math.max(0, v[i]);
                }
                break;
            case TANH:
                for (int i = 0, l = v.length; i < l; i++) {
                    v[i] = Math.tanh(v[i]);
                }
                break;
            case SOFTMAX:
                double max = Double.NEGATIVE_INFINITY;
                for (double x : v) {
                    if (x > max) {
                        max = x;
                    }
                }
                for (int i = 0, l = v.length; i < l; i++) {
                    v[i] = Math.exp(v[i] - max);
                }
                double sum = 0.;
                for (double x : v) {
                    sum += x;
                }
                for (int i = 0, l = v.length; i < l; i++) {
                    v[i] /= sum;
                }
                break;
        }
        return v;
    }

    public int predict(double[] neurons) {
        this.network[0] = neurons;

        for (int i = 0; i < this.network.length - 1; i++) {
            for (int j = 0; j < this.network[i + 1].length; j++) {
                for (int l = 0; l < this.network[i].length; l++) {
                    this.network[i + 1][j] += this.network[i][l] * this.weights[i][l][j];
                }
                this.network[i + 1][j] += this.bias[i][j];
            }
            if ((i + 1) < (this.network.length - 1)) {
                this.network[i + 1] = this.compute(this.hidden, this.network[i + 1]);
            }
        }
        this.network[this.network.length - 1] = this.compute(this.output, this.network[this.network.length - 1]);

        if (this.network[this.network.length - 1].length == 1) {
            if (this.network[this.network.length - 1][0] > .5) {
                return 1;
            }
            return 0;
        } else {
            int classIdx = 0;
            for (int i = 0; i < this.network[this.network.length - 1].length; i++) {
                classIdx = this.network[this.network.length - 1][i] > this.network[this.network.length - 1][classIdx] ? i : classIdx;
            }
            return classIdx;
        }

    }

    public static void main(String[] args) {
        if (args.length == 4) {

            // Features:
            double[] features = new double[args.length];
            for (int i = 0, l = args.length; i < l; i++) {
                features[i] = Double.parseDouble(args[i]);
            }

            // Parameters:
            int[] layers = {50, 3};
            double[][][] weights = {{{-0.05531723699609843, -0.287200025420079, -0.3332484895550319, -0.13177488666652298, -0.23548999991027228, -0.27176726174820537, -0.20915446535095505, -0.10106487799341365, -0.7768321356902204, -1.5542208226009833, -0.05386893454482406, -0.24905294696360633, -0.4175880045160095, -0.5451600682317279, -0.3150668128418591, -0.7422815097649155, -0.49552610461769975, -0.5997542299234682, -0.6973876805774772, -0.2012604837848982, 0.3889977069162521, 0.2273146629978264, -0.12438067366550552, 0.16600567918771103, -0.3366684197480092, 0.24592407503337607, -0.2766300586111286, -0.30728888673110144, -0.2201073774364332, 0.1862932202140711, -0.26776187233074394, -0.5956833945174421, -0.7478009637138988, 0.022109619813295678, -0.05700975229010219, -0.31868691845257185, -1.0423827254861882, -1.275824755479885, -0.3211328669780091, -1.0750119547979755, 0.11157830632073124, -0.12687180016606509, -0.1463668974088521, -0.15067020097080916, -0.26450917438245003, -0.034736753415490894, -0.08071131537953455, -0.13758701962047676, -0.14147945909943982, -0.2466412581379032}, {-0.32041376517165965, -0.05836142266591636, -0.19224296505426625, -0.1562981964260091, -0.005617748941114245, -0.29775062478352105, 0.04941046278461, -0.3453562427114212, -0.21581071070780053, -1.0040949497122218, -0.26510354422122756, -0.3080856034380936, -0.056996092855997375, -0.6131605204837428, -0.3000232902894183, -0.5321817288610283, -0.09509012584465502, -0.3020655499512411, -0.06974417619403356, 0.05770187230353784, 0.785112825079831, -0.37824248939802313, -0.2404762337015279, 0.5339547417111554, -0.6340088791166932, -0.2940844445115775, 0.28499837065076183, -0.10148680966767495, 0.1672037568627906, 0.27233503316006796, 0.2555308044413266, -0.12214235570228323, -0.4679693956018563, -0.1007318409821162, -0.42269305538095053, 0.12552593103638957, -0.6411809912790306, -0.782507889424831, 0.10895818870308205, -0.5867802036465367, -0.28938077197085876, 0.7055663869707898, -0.03339104991615795, -0.2392580455795085, -0.06124055220494313, -0.17531082554277327, 0.018062479293991646, 0.04911839110348018, -0.3314112354587498, 0.07809459515643513}, {-0.11556708521075042, -0.31339762843888086, 0.2572880978604231, -0.09515103955152604, 0.27234974385273425, 0.08223795642039296, -0.32277751423050166, 0.6643666067538009, -0.5868777228196822, 0.6762125018848691, -0.21843402806870219, -0.352561909073705, 0.2672384369924125, 0.32312472782708235, -0.28932575697624946, -0.19284652298178176, -0.17472825425090205, -0.20936801993985416, 0.11380387389520923, -0.2504795659186837, -1.2604871953609638, -0.13525225310626718, -0.31445423209067447, -0.812547016918379, 0.9534926082450407, 0.2011284844354027, 0.035213744355836595, 0.22801471480271154, -0.25054466214146737, -0.6010638644943938, 0.05717137328838323, -0.2020833240639735, -0.39580233324722136, -0.32089353179754704, 0.5908996376284825, -0.23293611412081117, -0.7098164781888164, -0.40823182141354747, 0.242354986458547, -0.6730438498781298, -0.27149973362266727, -1.6926542893107992, -0.2933806412143908, -0.18116012868489537, -0.3036242510277751, -0.2616638327399045, -0.3121301087393156, 0.14198899193882258, 0.039810294731032425, -0.3249542998866052}, {-0.285343121703217, 0.22155141721492505, 0.04539913718168689, -0.1977994093225928, -0.16511191233684375, 0.1625463775944756, -0.2030417764988245, 0.2805740749369515, 0.03874959579726595, 0.8204698598903679, -0.17343035476714458, -0.024664581823196405, 0.07873934733428413, 0.5543720394060162, -0.22879983655199124, -0.4341090932685532, -0.39670659359541904, -0.18604545473956502, 0.0878518387161701, 0.045899774382148664, -0.5875646578556978, 0.41857316005322753, 0.05316210853870095, -0.359401427845156, 0.6252361981746992, 0.255409584031532, 0.1128190195621137, -0.156716253451322, -0.2891026549254114, -0.305845271799076, 0.08647610794566317, -0.37228388641979465, 0.07850815306461306, -0.28896822710828896, 0.07400382553583477, 0.19336868534148596, -0.5601674371040103, 0.04040233206845901, 0.016446448623794588, -0.016809639315931692, -0.27331602665209465, -0.9155318605037801, 0.15670660100122782, 0.255277940082994, 0.27187022394120364, 0.2879739531161909, -0.3729824176631669, -0.17708737580969713, 0.07785023037880102, 0.29933649442316046}}, {{-0.3361840896139352, 0.32081563270400965, -0.08305024558717926}, {-0.011298257624238184, 0.2562037264645608, 0.36565565904247127}, {0.05027406123597418, 0.08618365620838876, -0.14428769955301007}, {0.05843095655112261, 0.16824195293379818, 0.24111268839666325}, {0.17164715947475379, 0.13327455089508367, 0.2452615732665165}, {-0.11931959419316242, 0.11492535400029204, -0.03305738170696759}, {-0.07933414655584556, -0.06001586600463217, -0.06629529779914677}, {-0.8683787278287711, 0.27767857503847926, 0.5029195369584982}, {-0.11245016848906347, 0.2826229276566894, 0.06546178954645863}, {-0.6353164852707057, 0.12554082668192784, 0.555322228567681}, {0.25570435642332295, 0.2717489802433174, 0.10949566393585161}, {-0.03238870338066207, -0.23808567718655996, 0.18802476803113657}, {-0.16403462094143173, -0.04754329486770832, 0.48233253248114044}, {-0.3479430419801183, 0.27323594914402793, 0.42666423415569105}, {0.046335177624920613, -0.023074139253265364, -0.10585608609482347}, {-0.049462559788541895, -0.5541623159350335, -0.051950479717014}, {0.31541096979521055, -0.2110142438158008, 0.21610231218402715}, {-0.22272121264684921, 0.13386214408938715, 0.5147226698060525}, {-0.021978343487825205, -0.38712752123919403, 0.05449852018510285}, {-0.18502315074911943, 0.06225209053889778, -0.1263253592882017}, {1.4345239976530477, -0.8176950775583972, -0.22448288588181867}, {-0.2842463512768766, -0.2313683452189076, 0.04689183310794551}, {0.15381740622571935, -0.19635865627055554, -0.16955054465901914}, {0.7438098762165719, -0.22439680979960874, -0.2608775306709127}, {-0.7182308809245205, -0.09618233208229632, 0.3780251268086543}, {-0.7164784542621164, 0.22684648185168335, 0.3224940122522215}, {0.029614815240288896, 0.10372044373398122, -0.2391885875166292}, {0.16925539019209196, -0.18703569980133156, 0.013022021267097478}, {0.1919783311717777, -0.32142826473012015, -0.11818812229201414}, {0.8643336346706185, -0.3925883592699291, 0.03702294126383572}, {0.24669409212237822, 0.30267860407806646, 0.21964228201240318}, {-0.16933107162740296, -0.10368470623812824, 0.3431081419486641}, {0.32704951736404847, -0.27309365479113207, 0.3587885337135242}, {-0.31319132287025997, 0.1818461392583339, 0.15593231469800903}, {-0.8394767452198691, 0.12118238796314447, 0.48215099989482807}, {-0.08128776148447113, 0.18552285385421144, -0.04499729682607364}, {-0.08279347807430275, 0.17955519951422247, 0.28913122476828207}, {0.8364367692241214, -1.1402922591937532, 0.36164065744120727}, {0.31945717823254016, 0.09192236185887058, 0.3323586339416823}, {-0.29788778168453406, -0.3349373087798128, 0.43628488510789515}, {-0.38373983087179053, -0.1300063596571257, -0.13068625075532805}, {0.8966978503571496, -0.15396137465027315, -0.9832477068576775}, {0.24382613862453847, 0.1296546148631104, 0.1284867301263389}, {0.2902051898628873, -0.3101814523419354, -0.17374186389051466}, {0.3295584453519629, -0.1992442345041733, -0.16975285927075562}, {-0.16003592748770948, 0.16834332643725702, -0.028951699361519388}, {0.8838381734457926, -1.0264009967355232, -0.34367745531628613}, {0.2009335529387164, -0.1363775843696629, -0.3178782755787516}, {0.06287155778970704, 0.23137339887648598, -0.08006534854786675}, {0.1681319639169432, 0.007497203601473629, 0.027556847615859456}}};
            double[][] bias = {{0.30011741283138643, -0.035979900154678855, 0.27707089984418304, 0.09437747263089169, -0.0733281905725025, -0.009339555268726818, 0.06954032194664883, -0.005685230901411666, 0.18615368012821823, -0.15027837700332358, -0.07008292471763006, 0.23571139223322518, -0.261230851707333, -0.47151668818437686, -0.24328056130217912, -0.15169537034679342, -0.3886910386782494, 0.1981680192932941, 0.12179738516009957, -0.32332067950525173, -0.09202577409330663, -0.14370913611358427, -0.24600210345938872, 0.26240311472639255, -0.3265337728594152, 0.27649270655201935, 0.05467611996472005, 0.2525546562745627, 0.22982296359481452, 0.2950313004271297, -0.02674648945546204, -0.05000417024200673, -0.0011395225086257164, -0.1428540988439016, -0.08278558656426527, 0.028427639916320597, -0.5003565563714583, -0.23491012179639284, -0.044215767340361145, 0.0023483266956502162, -0.14837939545281392, 0.31045037001889203, 0.05190481018969034, -0.29384801677586486, 0.19195282255033613, 0.0746874513632893, -0.36145703689926734, -0.05320421333257852, 0.11937922437695309, 0.2790678519850172}, {0.30487821665512077, 0.7354282585071136, -0.8413097834553672}};

            // Prediction:
            MLPClassifier clf = new MLPClassifier("relu", "softmax", layers, weights, bias);
            int estimation = clf.predict(features);
            System.out.println(estimation);

        }
    }
}

Run classification in Java


In [5]:
# Save classifier:
# with open('MLPClassifier.java', 'w') as f:
#     f.write(output)

# Compile model:
# $ javac -cp . MLPClassifier.java

# Run classification:
# $ java MLPClassifier 1 2 3 4