sklearn-porter

Repository: https://github.com/nok/sklearn-porter

MLPClassifier

Documentation: sklearn.neural_network.MLPClassifier


In [1]:
import sys
# Make the local sklearn-porter package importable from this example notebook.
sys.path.append('../../../../..')

Load data


In [2]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

iris_data = load_iris()
X = iris_data.data
y = iris_data.target

# Shuffle features and targets with the same permutation.
X, y = shuffle(X, y, random_state=0)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.4, random_state=5)

print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)


((90, 4), (90,))
((60, 4), (60,))

Train classifier


In [3]:
from sklearn.neural_network import MLPClassifier

clf = MLPClassifier(activation='relu', hidden_layer_sizes=50,
                    max_iter=500, alpha=1e-4, solver='sgd',
                    tol=1e-4, random_state=1, learning_rate_init=.1)
clf.fit(X_train, y_train)


Out[3]:
MLPClassifier(activation='relu', alpha=0.0001, batch_size='auto', beta_1=0.9,
       beta_2=0.999, early_stopping=False, epsilon=1e-08,
       hidden_layer_sizes=50, learning_rate='constant',
       learning_rate_init=0.1, max_iter=500, momentum=0.9,
       n_iter_no_change=10, nesterovs_momentum=True, power_t=0.5,
       random_state=1, shuffle=True, solver='sgd', tol=0.0001,
       validation_fraction=0.1, verbose=False, warm_start=False)
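
Before transpiling, a quick sanity check on the held-out split confirms the network actually learned something. This is an optional step, not part of the transpilation itself; a minimal sketch using the clf, X_test and y_test defined above:

# Optional sanity check: accuracy on the 40 % held-out split.
print('Test accuracy: {:.3f}'.format(clf.score(X_test, y_test)))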

Transpile classifier


In [4]:
from sklearn_porter import Porter

porter = Porter(clf, language='js')
output = porter.export()

print(output)


var MLPClassifier = function(hidden, output, layers, weights, bias) {

    this.hidden = hidden.toUpperCase();
    this.output = output.toUpperCase();
    this.network = new Array(layers.length + 1);
    for (var i = 0, l = layers.length; i < l; i++) {
        this.network[i + 1] = new Array(layers[i]).fill(0.);
    }
    this.weights = weights;
    this.bias = bias;

    var compute = function(activation, v) {
        switch (activation) {
            case 'LOGISTIC':
                for (var i = 0, l = v.length; i < l; i++) {
                    v[i] = 1. / (1. + Math.exp(-v[i]));
                }
                break;
            case 'RELU':
                for (var i = 0, l = v.length; i < l; i++) {
                    v[i] = Math.max(0, v[i]);
                }
                break;
            case 'TANH':
                for (var i = 0, l = v.length; i < l; i++) {
                    v[i] = Math.tanh(v[i]);
                }
                break;
            case 'SOFTMAX':
                var max = Number.NEGATIVE_INFINITY;
                for (var i = 0, l = v.length; i < l; i++) {
                    if (v[i] > max) {
                        max = v[i];
                    }
                }
                for (var i = 0, l = v.length; i < l; i++) {
                    v[i] = Math.exp(v[i] - max);
                }
                var sum = 0.0;
                for (var i = 0, l = v.length; i < l; i++) {
                    sum += v[i];
                }
                for (var i = 0, l = v.length; i < l; i++) {
                    v[i] /= sum;
                }
                break;
        }
        return v;
    };

    this.predict = function(neurons) {
        this.network[0] = neurons;

        for (var i = 0; i < this.network.length - 1; i++) {
            for (var j = 0; j < this.network[i + 1].length; j++) {
                for (var l = 0; l < this.network[i].length; l++) {
                    this.network[i + 1][j] += this.network[i][l] * this.weights[i][l][j];
                }
                this.network[i + 1][j] += this.bias[i][j];
            }
            if ((i + 1) < (this.network.length - 1)) {
                this.network[i + 1] = compute(this.hidden, this.network[i + 1]);
            }
        }
        this.network[this.network.length - 1] = compute(this.output, this.network[this.network.length - 1]);

        if (this.network[this.network.length - 1].length == 1) {
            if (this.network[this.network.length - 1][0] > .5) {
                return 1;
            }
            return 0;
        } else {
            var classIdx = 0;
            for (var i = 0, l = this.network[this.network.length - 1].length; i < l; i++) {
                classIdx = this.network[this.network.length - 1][i] > this.network[this.network.length - 1][classIdx] ? i : classIdx;
            }
            return classIdx;
        }

    };

};

if (typeof process !== 'undefined' && typeof process.argv !== 'undefined') {
    if (process.argv.length - 2 === 4) {

        // Features:
        var features = process.argv.slice(2);

        // Parameters:
        const layers = [50, 3];
        const weights = [[[-0.05531723699609843, -0.287200025420079, -0.3332484895550319, -0.13177488666652298, -0.23548999991027228, -0.27176726174820537, -0.20915446535095505, -0.10106487799341365, -0.7768321356902204, -1.5542208226009833, -0.05386893454482406, -0.24905294696360633, -0.4175880045160095, -0.5451600682317279, -0.3150668128418591, -0.7422815097649155, -0.49552610461769975, -0.5997542299234682, -0.6973876805774772, -0.2012604837848982, 0.3889977069162521, 0.2273146629978264, -0.12438067366550552, 0.16600567918771103, -0.3366684197480092, 0.24592407503337607, -0.2766300586111286, -0.30728888673110144, -0.2201073774364332, 0.1862932202140711, -0.26776187233074394, -0.5956833945174421, -0.7478009637138988, 0.022109619813295678, -0.05700975229010219, -0.31868691845257185, -1.0423827254861882, -1.275824755479885, -0.3211328669780091, -1.0750119547979755, 0.11157830632073124, -0.12687180016606509, -0.1463668974088521, -0.15067020097080916, -0.26450917438245003, -0.034736753415490894, -0.08071131537953455, -0.13758701962047676, -0.14147945909943982, -0.2466412581379032], [-0.32041376517165965, -0.05836142266591636, -0.19224296505426625, -0.1562981964260091, -0.005617748941114245, -0.29775062478352105, 0.04941046278461, -0.3453562427114212, -0.21581071070780053, -1.0040949497122218, -0.26510354422122756, -0.3080856034380936, -0.056996092855997375, -0.6131605204837428, -0.3000232902894183, -0.5321817288610283, -0.09509012584465502, -0.3020655499512411, -0.06974417619403356, 0.05770187230353784, 0.785112825079831, -0.37824248939802313, -0.2404762337015279, 0.5339547417111554, -0.6340088791166932, -0.2940844445115775, 0.28499837065076183, -0.10148680966767495, 0.1672037568627906, 0.27233503316006796, 0.2555308044413266, -0.12214235570228323, -0.4679693956018563, -0.1007318409821162, -0.42269305538095053, 0.12552593103638957, -0.6411809912790306, -0.782507889424831, 0.10895818870308205, -0.5867802036465367, -0.28938077197085876, 0.7055663869707898, -0.03339104991615795, -0.2392580455795085, -0.06124055220494313, -0.17531082554277327, 0.018062479293991646, 0.04911839110348018, -0.3314112354587498, 0.07809459515643513], [-0.11556708521075042, -0.31339762843888086, 0.2572880978604231, -0.09515103955152604, 0.27234974385273425, 0.08223795642039296, -0.32277751423050166, 0.6643666067538009, -0.5868777228196822, 0.6762125018848691, -0.21843402806870219, -0.352561909073705, 0.2672384369924125, 0.32312472782708235, -0.28932575697624946, -0.19284652298178176, -0.17472825425090205, -0.20936801993985416, 0.11380387389520923, -0.2504795659186837, -1.2604871953609638, -0.13525225310626718, -0.31445423209067447, -0.812547016918379, 0.9534926082450407, 0.2011284844354027, 0.035213744355836595, 0.22801471480271154, -0.25054466214146737, -0.6010638644943938, 0.05717137328838323, -0.2020833240639735, -0.39580233324722136, -0.32089353179754704, 0.5908996376284825, -0.23293611412081117, -0.7098164781888164, -0.40823182141354747, 0.242354986458547, -0.6730438498781298, -0.27149973362266727, -1.6926542893107992, -0.2933806412143908, -0.18116012868489537, -0.3036242510277751, -0.2616638327399045, -0.3121301087393156, 0.14198899193882258, 0.039810294731032425, -0.3249542998866052], [-0.285343121703217, 0.22155141721492505, 0.04539913718168689, -0.1977994093225928, -0.16511191233684375, 0.1625463775944756, -0.2030417764988245, 0.2805740749369515, 0.03874959579726595, 0.8204698598903679, -0.17343035476714458, -0.024664581823196405, 0.07873934733428413, 0.5543720394060162, -0.22879983655199124, 
-0.4341090932685532, -0.39670659359541904, -0.18604545473956502, 0.0878518387161701, 0.045899774382148664, -0.5875646578556978, 0.41857316005322753, 0.05316210853870095, -0.359401427845156, 0.6252361981746992, 0.255409584031532, 0.1128190195621137, -0.156716253451322, -0.2891026549254114, -0.305845271799076, 0.08647610794566317, -0.37228388641979465, 0.07850815306461306, -0.28896822710828896, 0.07400382553583477, 0.19336868534148596, -0.5601674371040103, 0.04040233206845901, 0.016446448623794588, -0.016809639315931692, -0.27331602665209465, -0.9155318605037801, 0.15670660100122782, 0.255277940082994, 0.27187022394120364, 0.2879739531161909, -0.3729824176631669, -0.17708737580969713, 0.07785023037880102, 0.29933649442316046]], [[-0.3361840896139352, 0.32081563270400965, -0.08305024558717926], [-0.011298257624238184, 0.2562037264645608, 0.36565565904247127], [0.05027406123597418, 0.08618365620838876, -0.14428769955301007], [0.05843095655112261, 0.16824195293379818, 0.24111268839666325], [0.17164715947475379, 0.13327455089508367, 0.2452615732665165], [-0.11931959419316242, 0.11492535400029204, -0.03305738170696759], [-0.07933414655584556, -0.06001586600463217, -0.06629529779914677], [-0.8683787278287711, 0.27767857503847926, 0.5029195369584982], [-0.11245016848906347, 0.2826229276566894, 0.06546178954645863], [-0.6353164852707057, 0.12554082668192784, 0.555322228567681], [0.25570435642332295, 0.2717489802433174, 0.10949566393585161], [-0.03238870338066207, -0.23808567718655996, 0.18802476803113657], [-0.16403462094143173, -0.04754329486770832, 0.48233253248114044], [-0.3479430419801183, 0.27323594914402793, 0.42666423415569105], [0.046335177624920613, -0.023074139253265364, -0.10585608609482347], [-0.049462559788541895, -0.5541623159350335, -0.051950479717014], [0.31541096979521055, -0.2110142438158008, 0.21610231218402715], [-0.22272121264684921, 0.13386214408938715, 0.5147226698060525], [-0.021978343487825205, -0.38712752123919403, 0.05449852018510285], [-0.18502315074911943, 0.06225209053889778, -0.1263253592882017], [1.4345239976530477, -0.8176950775583972, -0.22448288588181867], [-0.2842463512768766, -0.2313683452189076, 0.04689183310794551], [0.15381740622571935, -0.19635865627055554, -0.16955054465901914], [0.7438098762165719, -0.22439680979960874, -0.2608775306709127], [-0.7182308809245205, -0.09618233208229632, 0.3780251268086543], [-0.7164784542621164, 0.22684648185168335, 0.3224940122522215], [0.029614815240288896, 0.10372044373398122, -0.2391885875166292], [0.16925539019209196, -0.18703569980133156, 0.013022021267097478], [0.1919783311717777, -0.32142826473012015, -0.11818812229201414], [0.8643336346706185, -0.3925883592699291, 0.03702294126383572], [0.24669409212237822, 0.30267860407806646, 0.21964228201240318], [-0.16933107162740296, -0.10368470623812824, 0.3431081419486641], [0.32704951736404847, -0.27309365479113207, 0.3587885337135242], [-0.31319132287025997, 0.1818461392583339, 0.15593231469800903], [-0.8394767452198691, 0.12118238796314447, 0.48215099989482807], [-0.08128776148447113, 0.18552285385421144, -0.04499729682607364], [-0.08279347807430275, 0.17955519951422247, 0.28913122476828207], [0.8364367692241214, -1.1402922591937532, 0.36164065744120727], [0.31945717823254016, 0.09192236185887058, 0.3323586339416823], [-0.29788778168453406, -0.3349373087798128, 0.43628488510789515], [-0.38373983087179053, -0.1300063596571257, -0.13068625075532805], [0.8966978503571496, -0.15396137465027315, -0.9832477068576775], [0.24382613862453847, 0.1296546148631104, 0.1284867301263389], 
[0.2902051898628873, -0.3101814523419354, -0.17374186389051466], [0.3295584453519629, -0.1992442345041733, -0.16975285927075562], [-0.16003592748770948, 0.16834332643725702, -0.028951699361519388], [0.8838381734457926, -1.0264009967355232, -0.34367745531628613], [0.2009335529387164, -0.1363775843696629, -0.3178782755787516], [0.06287155778970704, 0.23137339887648598, -0.08006534854786675], [0.1681319639169432, 0.007497203601473629, 0.027556847615859456]]];
        const bias = [[0.30011741283138643, -0.035979900154678855, 0.27707089984418304, 0.09437747263089169, -0.0733281905725025, -0.009339555268726818, 0.06954032194664883, -0.005685230901411666, 0.18615368012821823, -0.15027837700332358, -0.07008292471763006, 0.23571139223322518, -0.261230851707333, -0.47151668818437686, -0.24328056130217912, -0.15169537034679342, -0.3886910386782494, 0.1981680192932941, 0.12179738516009957, -0.32332067950525173, -0.09202577409330663, -0.14370913611358427, -0.24600210345938872, 0.26240311472639255, -0.3265337728594152, 0.27649270655201935, 0.05467611996472005, 0.2525546562745627, 0.22982296359481452, 0.2950313004271297, -0.02674648945546204, -0.05000417024200673, -0.0011395225086257164, -0.1428540988439016, -0.08278558656426527, 0.028427639916320597, -0.5003565563714583, -0.23491012179639284, -0.044215767340361145, 0.0023483266956502162, -0.14837939545281392, 0.31045037001889203, 0.05190481018969034, -0.29384801677586486, 0.19195282255033613, 0.0746874513632893, -0.36145703689926734, -0.05320421333257852, 0.11937922437695309, 0.2790678519850172], [0.30487821665512077, 0.7354282585071136, -0.8413097834553672]];

        // Prediction:
        var clf = new MLPClassifier('relu', 'softmax', layers, weights, bias);
        var prediction = clf.predict(features);
        console.log(prediction);

    }
}
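
The exported predict function mirrors scikit-learn's forward pass: each layer computes a weighted sum plus bias, the hidden layers apply ReLU, the output layer applies softmax, and the index of the largest probability is returned. As a rough, illustrative NumPy sketch of that same computation (not the library's code), using the fitted clf from above:

import numpy as np

def forward(x, coefs, intercepts):
    # Propagate a single sample through the network.
    a = np.asarray(x, dtype=float)
    for i, (W, b) in enumerate(zip(coefs, intercepts)):
        a = a @ W + b
        if i < len(coefs) - 1:
            a = np.maximum(a, 0.)        # hidden layers: ReLU
        else:
            a = np.exp(a - a.max())      # output layer: softmax
            a /= a.sum()
    return int(np.argmax(a))

# Expected to agree with clf.predict([[5.1, 3.5, 1.4, 0.2]]) for this model.
print(forward([5.1, 3.5, 1.4, 0.2], clf.coefs_, clf.intercepts_))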

Run classification in JavaScript


In [5]:
# Save classifier:
# with open('MLPClassifier.js', 'w') as f:
#     f.write(output)

# Run classification:
# if hash node 2> /dev/null; then
#     node MLPClassifier.js 1 2 3 4
# fi
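
Alternatively, the same round trip can be driven from Python, for example with subprocess. A minimal sketch; it assumes node is installed and on the PATH, and reuses the output string produced above:

# Minimal sketch: write the transpiled classifier to disk and call it with Node.js.
import subprocess

with open('MLPClassifier.js', 'w') as f:
    f.write(output)

prediction = subprocess.check_output(
    ['node', 'MLPClassifier.js', '1', '2', '3', '4'])
print(prediction.strip())  # predicted class index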