sklearn-porter

Repository: https://github.com/nok/sklearn-porter

MLPRegressor

Documentation: sklearn.neural_network.MLPRegressor


In [1]:
import sys
# Make the local sklearn-porter checkout importable from this example directory.
sys.path.append('../../../../..')

Load data


In [2]:
from sklearn.datasets import load_diabetes

samples = load_diabetes()
X = samples.data
y = samples.target

print(X.shape, y.shape)


(442, 10) (442,)
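Each of the 442 samples has 10 standardized features; a quick look at one record (an addition, not part of the original notebook) shows the format that the transpiled regressor will later expect on the command line:

# Optional sanity check (not in the original notebook):
# one record of the diabetes dataset and its target value.
print(X[0])  # 10 mean-centered, scaled feature values
print(y[0])  # disease-progression target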

Train regressor


In [3]:
from sklearn.neural_network import MLPRegressor

reg = MLPRegressor(
    activation='relu', hidden_layer_sizes=30, max_iter=500, alpha=1e-4,
    solver='sgd', tol=1e-4, random_state=1, learning_rate_init=.1)
reg.fit(X, y)


Out[3]:
MLPRegressor(activation='relu', alpha=0.0001, batch_size='auto', beta_1=0.9,
       beta_2=0.999, early_stopping=False, epsilon=1e-08,
       hidden_layer_sizes=30, learning_rate='constant',
       learning_rate_init=0.1, max_iter=500, momentum=0.9,
       n_iter_no_change=10, nesterovs_momentum=True, power_t=0.5,
       random_state=1, shuffle=True, solver='sgd', tol=0.0001,
       validation_fraction=0.1, verbose=False, warm_start=False)
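Before transpiling, it can help to record a reference prediction from the fitted scikit-learn model; this step is an addition to the original notebook, and the feature values match the ones passed to the JavaScript port further below:

# Reference prediction in Python (addition, not in the original notebook).
sample = [[0.03, 0.05, 0.06, 0.02, -0.04, -0.03, -0.04, -0.002, 0.01, -0.01]]
print(reg.predict(sample))  # compare against the Node.js output below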

Transpile regressor


In [4]:
from sklearn_porter import Porter

porter = Porter(reg, language='js')
output = porter.export()

print(output)


var MLPRegressor = function(hidden, layers, weights, bias) {

    this.hidden = hidden.toUpperCase();
    this.network = new Array(layers.length + 1);
    for (var i = 0, l = layers.length; i < l; i++) {
        this.network[i + 1] = new Array(layers[i]).fill(0.);
    }
    this.weights = weights;
    this.bias = bias;

    var compute = function(activation, v, nLayers) {
        switch (activation) {
            case 'LOGISTIC':
                if (nLayers > 1) {
                    for (var i = 0, l = v.length; i < l; i++) {
                        v[i] = 1. / (1. + Math.exp(-v[i]));
                    }
                } else {
                    for (var i = 0, l = v.length; i < l; i++) {
                        if (v[i] > 0) {
                            v[i] = -Math.log(1. + Math.exp(-v[i]));
                        } else {
                            v[i] = v[i] - Math.log(1. + Math.exp(-v[i]));
                        }
                    }
                }
                break;
            case 'RELU':
                for (var i = 0, l = v.length; i < l; i++) {
                    v[i] = Math.max(0, v[i]);
                }
                break;
            case 'TANH':
                for (var i = 0, l = v.length; i < l; i++) {
                    v[i] = Math.tanh(v[i]);
                }
                break;
            case 'SOFTMAX':
                var max = Number.NEGATIVE_INFINITY;
                for (var i = 0, l = v.length; i < l; i++) {
                    if (v[i] > max) {
                        max = v[i];
                    }
                }
                for (var i = 0, l = v.length; i < l; i++) {
                    v[i] = Math.exp(v[i] - max);
                }
                var sum = 0.0;
                for (var i = 0, l = v.length; i < l; i++) {
                    sum += v[i];
                }
                for (var i = 0, l = v.length; i < l; i++) {
                    v[i] /= sum;
                }
                break;
        }
        return v;
    };

    this.predict = function(neurons) {
        this.network[0] = neurons;
    
        for (var i = 0; i < this.network.length - 1; i++) {
            for (var j = 0; j < this.network[i + 1].length; j++) {
                for (var l = 0; l < this.network[i].length; l++) {
                    this.network[i + 1][j] += this.network[i][l] * this.weights[i][l][j];
                }
                this.network[i + 1][j] += this.bias[i][j];
            }
            if ((i + 1) < (this.network.length - 1)) {
                this.network[i + 1] = compute(this.hidden, this.network[i + 1], this.network.length);
            }
        }
    
        if (this.network[this.network.length - 1].length > 1) {
            return this.network[this.network.length - 1];
        }
        return this.network[this.network.length - 1][0];
    };

};

if (typeof process !== 'undefined' && typeof process.argv !== 'undefined') {
    if (process.argv.length - 2 === 10) {

        // Features:
        var features = process.argv.slice(2);

        // Parameters:
        const layers = [30, 1];
        const weights = [[[-0.23211521668753296, 9.781501463388121, -0.5184097523554406, 1.3790049916333236, -0.27358550467982407, 0.03732895780818296, -0.24298963854466327, -0.11961231719348318, 2.7097644221517747, 0.030063334352258678, 10.369175001095302, 10.784539626717011, -0.2289000108203549, 0.2928497516987598, -0.36603555578600916, 0.8430195135808076, 5.836754974990878, 0.045454930223875756, -0.2785182210722789, -0.45359888297860085, 0.11062846279981702, 3.6301284665214686, -0.14450188708717862, 0.1489527454284922, 0.9549956153312039, 0.18729623570064166, -0.3213808409001305, -0.38242968683282647, 1.1961771841167308, 4.45532606986758], [-0.31153218012501505, 5.34659651738418, 0.3542776524607793, 0.8960378024707535, 0.14860770694259456, 0.14883741426266303, 0.14444388199720706, 0.2591656334669972, 1.1119349893621226, 0.19373531466755783, 6.2231501274650975, 6.1337079604225995, -0.1700448490860697, 0.22404515464082714, -0.3072991464327124, 0.2518744158256101, 3.623489669164054, -0.15984463974947347, -0.16436676391480332, -0.28713536597775635, 0.6764608002816442, 1.9378490970269495, -0.22334234419957208, -0.18158274646927255, 0.31973301291934925, -0.2259478746400986, 0.057403653669123526, -0.2736754664206129, 0.8867223723354193, 2.3946343120431073], [-1.1148404133812806, 25.141042239126474, -0.4803625170905369, 3.9325137088879143, -0.34855842422526334, -0.14459706505469946, 0.12685799850977264, 0.011531530550160853, 7.95340886457525, 0.06703637462800623, 27.770849035993862, 27.8078401274431, -0.2793783675986404, 0.2380727623464379, -0.07924869365409615, 1.1147101064259195, 15.859132442334063, -0.11790445476179623, 0.19425251273132194, -0.8814487754877787, 0.4586201667770763, 8.818456880441978, 0.19435345324727782, -0.11702735387031568, 1.3423974769698395, 0.031045669337296113, -0.055692954527946315, 0.23777253350060615, 4.434897788131679, 11.410082709197713], [-0.8780639623295481, 15.383043964024461, -0.4920842738859076, 2.440687772003422, -0.07114751109705034, -0.05851124978409383, 0.3124150887209643, 0.05706433319557886, 4.212308946178777, 0.09072805315206878, 16.269383473272555, 16.822273822556415, 0.29890990754629193, -0.11054373937930706, 0.31640809439790124, 1.0883285329821155, 8.900366437594991, 0.33259663577649917, 0.1478485505151078, -0.3738630937364774, 0.21178084993902113, 4.962549468572029, 0.3350426659654491, 0.1524345194769864, 0.49425712390358956, 0.017559213956613817, 0.196625628965266, 0.23980352127057028, 2.853206122456331, 6.565246941426061], [-0.6139837294666064, 7.095390932449621, -0.5546625267882224, 0.9678963783782841, 0.2788395488521895, -0.42056219435743325, 0.040910314768901626, 0.2649009335746527, 1.9595993874181632, -0.17102095420366073, 8.195950917698859, 8.680516167880933, 0.04726754927185238, -0.3728048701803646, 0.2328382536054911, 0.11418474728455252, 4.835152324426195, -0.0868512778858261, 0.28156104848272695, -0.12565334034230116, 0.03151666861681627, 2.3020198793471764, -0.3408411857000827, -0.2932672876021093, 0.018016858018200226, -0.2909025501630191, -0.21243651932233917, 0.12827382377763602, 1.1611908790060375, 2.9759231297789723], [-0.4472639302119193, 7.186298661058716, -0.03777657147481239, 0.8424639418857298, -0.1918222679627078, -0.17751672452275977, -0.23588809255460808, 0.06301201520317282, 2.3440699849565623, 0.2686168862451528, 7.204891768923059, 7.549433814004696, 0.0929050050643861, 0.2547937901686558, -0.26581306542778305, -0.17947184060205218, 3.85649532790901, -0.010575632030441056, 0.082351548610537, -0.0982506948587286, 
0.1656764385643189, 2.6971351966286354, 0.06176220773682902, -0.09282996354552223, 0.34338954256792137, 0.1870615565499244, 0.13106989008511805, -0.19960693659793804, 0.5625554690998593, 2.8513000569364677], [0.4061319100058589, -12.682885755709297, 0.43477638030396243, -2.317309259967748, -0.1856345598266541, 0.11404413001442733, -0.23743336215864771, 0.1080116397889521, -3.6820894306069722, 0.32901129828479675, -13.72934365229823, -14.173030927215368, 0.1820572191512012, 0.21080029826285537, 0.3158510019648884, -0.3802035268999079, -8.037880042147444, -0.20573501678110248, 0.09044415718603847, 0.7479973538473657, -0.5887811044255243, -4.228373082861249, 0.3218846967100511, 0.1096422041114381, -0.9351585698159597, 0.03893672493194371, 0.08078786134344858, 0.08468512345384675, -1.596254093599804, -5.167156786593848], [-0.4008071114298885, 12.619085059086633, -0.5022684922058759, 1.6458902504891153, -0.28262906618573536, -0.1260320790108488, -0.3705762508782902, 0.34695031886310684, 3.9035317440643973, -0.3756149740776198, 13.084766388100435, 13.495825309478985, -0.28579079400467916, 0.2396987380625257, -0.12025055798433057, 1.0110431658752315, 7.605716331809788, 0.2934031649782276, 0.2669948196484158, -0.10423086239039181, 0.6191693974271495, 4.24794425665857, 0.23126674178236, -0.16595950100142706, 0.7452200326847983, 0.06682467956284503, -0.37521665574438134, 0.024011217763222344, 1.773531467772138, 5.662378122659779], [-0.8665633849832532, 18.791187618622534, -0.5054252329910491, 2.6699639873702568, 0.22299951421820316, -0.11647475301141844, -0.34549467000313344, -0.061809529914129795, 5.823504613934345, 0.3242046384092999, 19.785797562465817, 21.033322933475088, -0.09558782722748976, 0.366942590353702, 0.08110201016333786, 1.4657405978035332, 11.464318498925845, 0.09919426945984353, -0.1660699206656437, -0.8800550503651883, 0.4063170535079751, 6.729464510863023, 0.19755966894449542, 0.15339418510698774, 1.5409641707033974, -0.17359987416884398, 0.1322749235064741, -0.14765871167830155, 3.031761833397193, 8.372966238928598], [-0.4891702363125708, 19.177633929826037, -0.2284190826124838, 3.020481659009847, 0.366956949141833, 0.08655738869418515, -0.23345586538194335, -0.05676964124448499, 5.536789434916171, 0.23051951972154372, 21.28360461254093, 21.739558708875766, 0.1260255467333044, -0.17797236236795616, -0.19179054705571985, 1.5583008117987633, 11.893388565183344, 0.23402199927190045, 0.05614193427098339, -0.36003637079001133, 0.4889388118835298, 6.809250632387282, 0.053330112657122515, -0.026557499267704242, 1.0374374151118078, -0.27787960218850455, -0.09454701193738549, -0.3881292739014529, 3.6120130717068677, 8.202504470270158]], [[33.485363986226695], [-1662.440069261993], [21.289741914863317], [-122.23337268466445], [0.25099382080737004], [-0.21200040816962099], [-0.1545199825785457], [0.3280845257492277], [-2235.0513858871113], [0.03381873652271508], [-2416.204118784538], [-2953.7608376473886], [0.28716186476178396], [0.31153882095543084], [-0.3530120181800926], [-503.3760612987556], [-1339.3938414318702], [0.0969861478533905], [0.26359139937591053], [41.999481897784236], [-2.342906024552729], [-1512.9766883050115], [-0.21140924669776326], [-0.21372223532210563], [-556.4713009108235], [0.6904053078701193], [0.26092871311699156], [15.016552429443928], [-1631.270537255238], [-3200.9350447987204]]];
        const bias = [[-46.638891566206546, -2244.6660881284865, -36.51265652503184, -363.8590523316024, -0.2626103958203553, -9.25759444987251, -0.11992921348066127, -0.21298313285170783, -584.9019397022495, -0.14541515843225822, -2413.4317511899135, -2443.2749180315323, -0.1881353606559993, -0.3014023024165611, -0.23783004511099995, -110.28109816480968, -1366.127005774272, -0.22603161607469066, -0.19517236644162575, -61.112431717530484, -25.60446300220539, -727.2860604143086, -0.20630132145150193, -0.30831483509783675, -129.80931719981317, -5.955273627903566, -0.2690393876185563, -7.008357242333273, -307.4675649851167, -878.6092553697476], [153.8462754980334]];

        // Prediction:
        var reg = new MLPRegressor('relu', layers, weights, bias);
        var prediction = reg.predict(features);
        console.log(prediction);

    }
}

Run regression in JavaScript


In [5]:
# Save regressor:
# with open('MLPRegressor.js', 'w') as f:
#     f.write(output)

# Run regression (shell snippet, requires Node.js):
# if hash node 2>/dev/null; then
#     node MLPRegressor.js 0.03 0.05 0.06 0.02 -0.04 -0.03 -0.04 -0.002 0.01 -0.01
# fi
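
As an alternative to the shell snippet above, the saved file can be executed directly from Python; the following sketch assumes Node.js is available on the PATH and is not part of the original notebook:

# Sketch: run the transpiled regressor with Node.js and compare it with
# scikit-learn's own prediction (assumes `node` is on the PATH).
import subprocess

with open('MLPRegressor.js', 'w') as f:
    f.write(output)

features = ['0.03', '0.05', '0.06', '0.02', '-0.04',
            '-0.03', '-0.04', '-0.002', '0.01', '-0.01']
js_result = subprocess.check_output(['node', 'MLPRegressor.js'] + features)
print(js_result.decode().strip())                   # prediction from the JS port
print(reg.predict([[float(v) for v in features]]))  # prediction from scikit-learn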