Implementation of SELUs & Dropout-SELUs in NumPy

This looks pretty neat. They can prove that with a slight modification of the ELU activation, the unit activations converge towards zero mean / unit variance (if the network is deep enough). If they're right, this might make batch norm obsolete, which would be a huge boon to training speeds!

The experiments look convincing too -- apparently it even beats BN+ReLU in accuracy.

One thing I wish they had shown is the resulting distribution of activations after training. Assuming their fixed-point proof holds, the activations should indeed stay near zero mean / unit variance, but it still would have been nice to see it plotted -- maybe they ran out of space in their appendix ;)

Weirdly, the exact ELU modification they proposed isn't stated explicitly in the paper!

For those wondering, it can be found in the available source code, and looks like this:


In [1]:
# An extra explanation from Reddit
# # Thanks, I will double-check the analytical solution. For the numerical one, could you please explain why running the following code results in a value close to 1 rather than 0?
# du = 0.001
# u_old = np.mean(selu(np.random.normal(0,    1, 100000000)))
# u_new = np.mean(selu(np.random.normal(0+du, 1, 100000000)))
# # print (u_new-u_old) / du
# print(u_old, u_new)
# # Now I see your problem:
# #     You do not consider the effect of the weights.
# #     From one layer to the next, we have two influences:
# #         (1) multiplication with the weights and
# #         (2) applying the SELU.
# #         (1) has a centering and symmetrising effect (draws the mean towards zero) and
# #         (2) has a variance-stabilizing effect (draws the variance towards 1).

# #         That is why we use the variables \mu & \omega and \nu & \tau to analyze both effects.
# # Oh yes, that's true: zero-mean weights completely kill the mean. Thanks!

# The SELU itself, ported from the TensorFlow implementation to NumPy
import numpy as np

def selu(x):
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946
    return scale * np.where(x>=0.0, x, alpha * (np.exp(x)-1))
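
For reference, written out as a formula, what this computes is (my transcription of the code above, using the constants from the source):

$$\operatorname{selu}(x) \;=\; \lambda \cdot \begin{cases} x & \text{if } x > 0 \\ \alpha\,(e^{x} - 1) & \text{if } x \le 0 \end{cases}, \qquad \alpha \approx 1.6733,\quad \lambda \approx 1.0507$$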

In [2]:
# The original TensorFlow implementation on GitHub (kept here, commented out, for reference)
# def dropout_selu(x, rate, alpha= -1.7580993408473766, fixedPointMean=0.0, fixedPointVar=1.0, 
#                  noise_shape=None, seed=None, name=None, training=False):
#     """Dropout to a value with rescaling."""

#     def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name):
#         keep_prob = 1.0 - rate
#         x = ops.convert_to_tensor(x, name="x")
#         if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
#             raise ValueError("keep_prob must be a scalar tensor or a float in the "
#                                              "range (0, 1], got %g" % keep_prob)
#         keep_prob = ops.convert_to_tensor(keep_prob, dtype=x.dtype, name="keep_prob")
#         keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())

#         alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha")
#         keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())

#         if tensor_util.constant_value(keep_prob) == 1:
#             return x

#         noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x)
#         random_tensor = keep_prob
#         random_tensor += random_ops.random_uniform(noise_shape, seed=seed, dtype=x.dtype)
#         binary_tensor = math_ops.floor(random_tensor)
#         ret = x * binary_tensor + alpha * (1-binary_tensor)

#         a = tf.sqrt(fixedPointVar / (keep_prob *((1-keep_prob) * tf.pow(alpha-fixedPointMean,2) + fixedPointVar)))

#         b = fixedPointMean - a * (keep_prob * fixedPointMean + (1 - keep_prob) * alpha)
#         ret = a * ret + b
#         ret.set_shape(x.get_shape())
#         return ret

#     with ops.name_scope(name, "dropout", [x]) as name:
#         return utils.smart_cond(training,
#             lambda: dropout_selu_impl(x, rate, alpha, noise_shape, seed, name),
#             lambda: array_ops.identity(x))
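
To see what the a and b in the code above are doing: dropped units are set to α' = -λα ≈ -1.7581 rather than to 0, which pulls a layer's mean and variance away from the (0, 1) fixed point, and a, b are the affine correction that maps them back. A sketch of the algebra (my own derivation of why the code's expressions work, assuming the input unit sits at the fixed point, i.e. mean 0 and variance 1):

$$\mathbb{E}[\tilde{x}] = (1-q)\,\alpha', \qquad \operatorname{Var}[\tilde{x}] = q\left(1 + (1-q)\,{\alpha'}^{2}\right)$$

where q is the keep probability and \tilde{x} is the unit after the alpha-dropout mask. Choosing

$$a = \frac{1}{\sqrt{q\left(1 + (1-q)\,{\alpha'}^{2}\right)}}, \qquad b = -a\,(1-q)\,\alpha'$$

gives a\tilde{x} + b mean 0 and variance 1 again, which is exactly what the code computes when fixedPointMean = 0 and fixedPointVar = 1.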

In [3]:
# """"""""""""""""""""""""""""""""""""""Dropout to a value with rescaling."""""""""""""""""""""""""""""""""""""""
# NumPy implementation
def dropout_selu(X, p_dropout):
    """Alpha dropout for SELU networks (NumPy port of the TensorFlow code above).

    Dropped units are set to alpha' = -lambda*alpha instead of 0, and the result is
    mapped back towards the (0, 1) fixed point with the affine transform a*x + b.
    """
    alpha = -1.7580993408473766   # = -lambda * alpha of the SELU
    fixedPointMean = 0.0
    fixedPointVar = 1.0
    keep_prob = 1.0 - p_dropout

    # Binary keep-mask: 1 with probability keep_prob, 0 otherwise
    random_tensor = keep_prob + np.random.uniform(size=X.shape)  # values in [keep_prob, 1 + keep_prob)
    binary_tensor = np.floor(random_tensor)
    ret = X * binary_tensor + alpha * (1 - binary_tensor)

    # Affine correction that restores mean 0 / variance 1 after the dropout step
    a = np.sqrt(fixedPointVar / (keep_prob * ((1 - keep_prob) * (alpha - fixedPointMean) ** 2 + fixedPointVar)))
    b = fixedPointMean - a * (keep_prob * fixedPointMean + (1 - keep_prob) * alpha)
    ret = a * ret + b
    return ret
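
As a quick sanity check of the NumPy port (my own snippet, not from the original post): feeding it standard-normal input should come back out with mean close to 0 and standard deviation close to 1, even with dropout applied.

In [ ]:
# Sanity check (added): alpha-dropout alone should roughly preserve the (0, 1) fixed point
x_check = np.random.normal(size=1000000)
y_check = dropout_selu(X=x_check, p_dropout=0.10)
print(y_check.mean(), y_check.std())  # expected: roughly 0.0 and 1.0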

In [5]:
# EDIT: For the fun of it, I ran a quick experiment to see if activations really stay close to mean 0 / std 1:
# 100 random dense layers, each followed by SELU and SELU-dropout (p_dropout=0.10).
# Each printed row: min/max of the per-example activation means, then min/max of the per-example stds.
x = np.random.normal(size=(300, 200))
for _ in range(100):
    w = np.random.normal(size=(200, 200), scale=np.sqrt(1/200))  # their initialization scheme
    x = x @ w
    x = selu(x=x)
    x = dropout_selu(X=x, p_dropout=0.10)
    mean = x.mean(axis=1)
    scale = x.std(axis=1) # standard deviation=square-root(variance)
    print(mean.min(), mean.max(), scale.min(), scale.max())


-0.223099757697 0.203214936577 0.863587755542 1.20470961504
-0.21691843908 0.215218979966 0.86360441051 1.17613900672
-0.161349395876 0.184066776651 0.865182654385 1.15937043217
-0.180408397624 0.19395959507 0.855662040461 1.13162258461
-0.220800799778 0.175262347506 0.821130427955 1.14018506117
-0.212409677402 0.206952594448 0.86887014689 1.19041961646
-0.16692049765 0.22821851292 0.870497914974 1.15283699163
-0.184918758795 0.256431444338 0.849269702629 1.14055418877
-0.207601510467 0.210564130481 0.879390880617 1.19656049533
-0.152709286671 0.232398634993 0.841774765762 1.15741897959
-0.173832959797 0.197930028588 0.849512582335 1.15532061272
-0.179813805349 0.180460649534 0.85772822303 1.13592202672
-0.180581762658 0.218309226162 0.839830731126 1.12568685287
-0.183970638364 0.194683101424 0.860145065655 1.16212083045
-0.20736713487 0.180152402687 0.864656431083 1.12264923507
-0.212344034816 0.225837809064 0.823923340602 1.14795471346
-0.245283079866 0.172039187098 0.854419442815 1.11906075925
-0.214533196755 0.200670858448 0.862304694874 1.15210663715
-0.177373436741 0.223944634595 0.851469433738 1.1432158645
-0.183395444936 0.210861328708 0.871093117738 1.13390114973
-0.166750398512 0.174761568831 0.856371703019 1.18376183399
-0.252384855475 0.20149725534 0.843370389574 1.15145011579
-0.221711966822 0.217442014352 0.838157922718 1.1424177135
-0.184348275379 0.247400750231 0.861331907094 1.16140959797
-0.234254292997 0.22977059251 0.852114337017 1.12626103429
-0.186098098745 0.21160714143 0.849900940331 1.12065880329
-0.170507602544 0.237435405815 0.84229094269 1.11735457476
-0.195856093174 0.186042646631 0.860172703791 1.13518176372
-0.181051413967 0.176787535416 0.876350522595 1.15705973605
-0.193819140466 0.192010325127 0.869278136043 1.12814709302
-0.206266458657 0.175250527974 0.832422889852 1.14328886939
-0.191495837593 0.183954675458 0.867421099812 1.12574300226
-0.207411250114 0.223207065849 0.85387997104 1.13343540478
-0.236732359651 0.208067457382 0.860207618329 1.15283240363
-0.175806197957 0.222663742065 0.874324101433 1.17841600194
-0.228324773204 0.242619622 0.85088535356 1.18396021668
-0.198568708982 0.256923884134 0.858579446772 1.20841305346
-0.205708185447 0.190704757006 0.867570017222 1.21556740104
-0.162400627058 0.223029528421 0.870322394892 1.13964731111
-0.174020716589 0.185036624145 0.848110497393 1.16655839129
-0.199527880554 0.217998130668 0.854321250849 1.20109965685
-0.189255179806 0.168837340366 0.843631748525 1.18277840795
-0.215548545311 0.225834701237 0.862987134389 1.20404020641
-0.176977603593 0.213382207324 0.842150783767 1.16116112338
-0.274021380425 0.259382436503 0.864187419513 1.22432638672
-0.161038141841 0.244330738677 0.830768484959 1.19116563839
-0.194407722159 0.207948998981 0.851060536714 1.17472488798
-0.175939543785 0.153791563302 0.831836272293 1.13589753709
-0.247249692257 0.161875977159 0.805711472173 1.18113753344
-0.176993282118 0.208521452192 0.868927338507 1.12609203986
-0.182454943159 0.200136119047 0.845007452867 1.14424239688
-0.179728659899 0.162855352271 0.874940294031 1.15578337206
-0.178642403484 0.19344249463 0.867033170275 1.14169269205
-0.264731275562 0.263238645706 0.863886581017 1.14459080243
-0.207352265347 0.197101110599 0.83178805335 1.16725482498
-0.205900546022 0.212949594799 0.848467585635 1.13340711152
-0.166140936899 0.20774825572 0.868965278339 1.16192000514
-0.177495155402 0.211962658046 0.838666576387 1.1733414143
-0.195584349549 0.201033313031 0.780040468684 1.18588998601
-0.240702451037 0.251688533001 0.866747677701 1.22475545298
-0.233035005923 0.198427760627 0.873528205218 1.16552074957
-0.223315682682 0.150691414981 0.865345481284 1.15004875654
-0.239142115045 0.187175872425 0.846284550412 1.1592772991
-0.230948064806 0.169995913675 0.857380608892 1.15915991791
-0.201141346253 0.178584627401 0.854870144563 1.15378083919
-0.278785102027 0.191935140718 0.866172028982 1.16079801541
-0.186210844675 0.203896849618 0.862530436681 1.16033779565
-0.199406276322 0.266687203949 0.864771654541 1.15015966498
-0.197901474381 0.167936841076 0.864306669567 1.15444393565
-0.19087498731 0.156564467914 0.859340146508 1.13187118144
-0.218658029491 0.200498813327 0.854370420607 1.12528145657
-0.225181238287 0.170009546187 0.847750494626 1.17538549217
-0.183179261977 0.171034485281 0.854596593801 1.17071678384
-0.203923763565 0.236517446082 0.868238246043 1.13737296948
-0.216448134492 0.232148335898 0.849609654803 1.14282778836
-0.203234638853 0.186466624543 0.883916117161 1.1572654249
-0.18180434917 0.196021755092 0.846549762138 1.15204124448
-0.174665825398 0.201034713156 0.854886868684 1.1683025203
-0.162223246661 0.210811539917 0.864975080831 1.14194986197
-0.156284013717 0.19047399415 0.837221122196 1.19135724107
-0.14571706209 0.169419568106 0.880637391716 1.18151231198
-0.162724552536 0.180999780637 0.840717285421 1.18241486758
-0.174777821917 0.156889478436 0.868163183909 1.15381975246
-0.184477876758 0.217959368287 0.869219405766 1.16452181314
-0.187077811411 0.300878642474 0.805547334947 1.17857545956
-0.218075406424 0.178849068992 0.835915029557 1.15031374203
-0.197318453356 0.20173317268 0.878213835508 1.12332307434
-0.218567501829 0.197974987703 0.866352053846 1.114650459
-0.229031963215 0.190381006319 0.842304486256 1.15640290387
-0.199987564807 0.186073203949 0.871355048247 1.14895110387
-0.229188464856 0.192649918465 0.85541092699 1.14070672124
-0.201769641258 0.275616310374 0.845415805046 1.14264097357
-0.207183782702 0.196146073795 0.872659912835 1.14620941067
-0.186768806403 0.257813880245 0.876183962581 1.15180820435
-0.232866190609 0.224958356535 0.864849265412 1.13239780912
-0.238300805211 0.203844115464 0.867146286664 1.1276327455
-0.226168479899 0.231569293343 0.875448494684 1.16459279586
-0.198093891043 0.159475158231 0.823739256475 1.15296189167
-0.204697999239 0.195255110504 0.843554780537 1.14441001338
-0.179649468635 0.211114376556 0.868685052054 1.1415210694

In [6]:
# My NumPy implementation of standard (inverted) dropout, as usually used with ReLUs.
# Note: despite the name, p_dropout is used as the KEEP probability here --
# units survive with probability p_dropout and are rescaled by 1/p_dropout.
def dropout_forward(X, p_dropout):
    u = np.random.binomial(1, p_dropout, size=X.shape) / p_dropout  # inverted-dropout mask
    out = X * u
    cache = u
    return out, cache

def dropout_backward(dout, cache):
    dX = dout * cache  # gradient flows only through the kept units
    return dX
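
A tiny check of the inverted-dropout scaling (my own snippet): since the kept units are divided by the keep probability, the expected activation is unchanged.

In [ ]:
# Check (added): inverted dropout preserves the expected activation
ones = np.ones(1000000)
out_check, _ = dropout_forward(X=ones, p_dropout=0.8)  # ~80% of units kept
print(out_check.mean())  # expected: close to 1.0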

In [7]:
# EDIT: same experiment, but with SELU followed by standard (inverted) dropout keeping ~80% of the units instead of SELU-dropout:
x = np.random.normal(size=(300, 200))
for _ in range(100):
    w = np.random.normal(size=(200, 200), scale=np.sqrt(1/200))  # their initialization scheme
    x = x @ w
    x = selu(x)
    x, _ = dropout_forward(p_dropout=0.8, X=x)
    mean = x.mean(axis=1)
    scale = x.std(axis=1) # standard deviation=square-root(variance)
    print(mean.min(), mean.max(), scale.min(), scale.max())


-0.231921552947 0.205240271306 0.959559397803 1.30683997472
-0.232878572258 0.244668575343 0.956916861795 1.43996258687
-0.244880507301 0.400038118486 0.998917067879 1.58661165578
-0.21619381187 0.339647741092 1.06985549361 1.68672987809
-0.204633155259 0.376022587439 1.05605942671 1.76512717898
-0.157701351545 0.393720624242 1.09954887384 1.85568666575
-0.155484269348 0.466262555265 1.13862266436 1.92573697266
-0.111770021666 0.469575897379 1.21037699743 2.124692154
-0.139937132059 0.490713331488 1.1721806166 2.11632850434
-0.217489684692 0.545963625836 1.15222184507 2.09077813307
-0.238708720474 0.569139391448 1.2111838962 2.10146437294
-0.211100122885 0.439907507471 1.23846479411 2.13017176194
-0.172591422552 0.529197632632 1.29130128106 2.30888912185
-0.257952512643 0.48331127522 1.30839150688 2.18686300781
-0.140180897111 0.55930176637 1.2910108532 2.26804240326
-0.207817976177 0.637984950226 1.19657196518 2.26216848596
-0.168422932788 0.597874378031 1.27361030041 2.3057602796
-0.137712104751 0.520977037006 1.25261746179 2.25324716721
-0.244672071397 0.501987298336 1.31841859022 2.14718702716
-0.198404907682 0.45921260366 1.34125196309 2.14904143669
-0.15269025403 0.651377878149 1.31025577186 2.25858709949
-0.13054755745 0.540751518148 1.30258717414 2.25841152636
-0.149940276302 0.483470663062 1.29513211513 2.23740448378
-0.119293727876 0.658633293281 1.24793805395 2.20108909035
-0.15668064803 0.593294443206 1.3074887817 2.23839613718
-0.120868276544 0.536212916257 1.31393331364 2.2966498815
-0.19113526684 0.552962266824 1.31240138479 2.21096451851
-0.156267821983 0.486342248572 1.32587524421 2.16473041244
-0.100235713142 0.501872234788 1.30004240668 2.1245499435
-0.215577883733 0.558059874424 1.24446091342 2.09742147583
-0.151220989861 0.513083724799 1.30188322034 2.13301478804
-0.150931883774 0.474825205687 1.28243506356 2.27040417428
-0.130117439733 0.571291880771 1.25732431701 2.21258248452
-0.21197832079 0.534264184275 1.27995967173 2.29343678968
-0.191943681882 0.470310155038 1.20411146637 2.15783972817
-0.156526715717 0.514564844177 1.23850957704 2.25976082645
-0.115004789649 0.486224466101 1.25174301414 2.09629537713
-0.183367258971 0.498513712753 1.27751129641 2.13292069986
-0.251735543734 0.527095585082 1.32338049434 2.22882288128
-0.142890512126 0.537222185107 1.29313149198 2.32764496683
-0.2204717533 0.604634995581 1.24751726514 2.26554010981
-0.196289718148 0.49668251988 1.18319889855 2.12554702282
-0.227830073458 0.442901288354 1.25786026831 2.24539148359
-0.169100154683 0.478516054257 1.32345692467 2.32660268297
-0.203810871545 0.693838422017 1.24578247204 2.26425617451
-0.1810764887 0.653014882553 1.27003020597 2.27709504605
-0.156060269717 0.495710684416 1.24837219857 2.16789043693
-0.188543976644 0.616475420305 1.301334088 2.24868280078
-0.174163884487 0.580005789536 1.26349025194 2.12587653732
-0.192533249111 0.699661889303 1.21882538908 2.11465094624
-0.177646162155 0.490735578896 1.30935173748 2.17174884646
-0.159634874634 0.485871700578 1.30468791475 2.12105800449
-0.193893405937 0.534731309052 1.28248554337 2.20213578945
-0.0732053795571 0.513070822044 1.24138322161 2.16048103727
-0.234102764336 0.463710150559 1.26599163142 2.13406730848
-0.164139632378 0.528122312311 1.28960286563 2.1753945805
-0.279578377228 0.691091848205 1.23199645377 2.17303283002
-0.255770040167 0.652419012407 1.25869600516 2.27549674573
-0.162198871127 0.659767150474 1.20988594055 2.18295838099
-0.255285788906 0.577120462096 1.28035067162 2.240415332
-0.211088016872 0.648110612244 1.30214395499 2.2474329407
-0.181460166701 0.54225868796 1.18581848823 2.14821759831
-0.193835513786 0.537578392014 1.30837981355 2.37544205701
-0.14297542186 0.524390224997 1.28502063584 2.27112927394
-0.140122614759 0.427371415377 1.26239524247 2.24411609376
-0.101553434351 0.535912759619 1.3020965274 2.19117889322
-0.188448906002 0.432695136584 1.28404226362 2.08953585259
-0.14849882729 0.522533741685 1.29634375613 2.04693745109
-0.177545815531 0.52171677062 1.24115487913 2.07659659268
-0.173130261063 0.503646261663 1.28442484971 2.17595419658
-0.144965098073 0.546067719285 1.28954611428 2.08696461088
-0.16740879377 0.4873644993 1.31893144728 2.17707918932
-0.134102945697 0.608891099153 1.30106093571 2.04850257368
-0.113528198173 0.625989170776 1.24855413036 2.08316682848
-0.1294751143 0.525761700236 1.26131010412 2.2874853649
-0.110109491689 0.539435933625 1.18001360556 2.26855909627
-0.160292591764 0.583602028259 1.17662707472 2.32415980979
-0.187473422336 0.577572024425 1.25404676056 2.28481203957
-0.146829772123 0.533221916155 1.30349445787 2.14922543821
-0.144979936519 0.654075972527 1.28359381833 2.14806158649
-0.142151684912 0.568919983805 1.2641604538 2.13314734075
-0.196541817647 0.636573386038 1.27225217548 2.1877467682
-0.101967482394 0.563821964524 1.30012925111 2.2833281028
-0.229796978962 0.643465005773 1.29166475842 2.32907792581
-0.175879573115 0.67722380654 1.32589282279 2.34613026676
-0.138692421095 0.487577391379 1.25328287623 2.29837566391
-0.20422426055 0.552415065155 1.33885200471 2.28679636742
-0.2292496433 0.538642486094 1.21377606928 2.29379226715
-0.183541705342 0.534063934405 1.32770224084 2.27239830938
-0.263299098567 0.592474237922 1.32373473774 2.3216829787
-0.198493089445 0.531723000953 1.32548848992 2.2989313628
-0.248517146878 0.640703333777 1.25956927917 2.23301863921
-0.200580403456 0.593550897715 1.30087838385 2.42208239484
-0.195335292293 0.554079006129 1.32459357056 2.36554920002
-0.135622309833 0.59298576476 1.28322186547 2.39302870105
-0.123539492804 0.610336440315 1.22596463097 2.43403888415
-0.0611041120411 0.518569055017 1.271566511 2.28450384828
-0.14337746928 0.645191691113 1.3610470694 2.33067328787
-0.136433976725 0.581199166658 1.28216111909 2.30439983808
-0.170143914489 0.535496162765 1.26387516842 2.12797083256

In [6]:
def elu_fwd(X):
    # Plain ELU: alpha = scale = 1 (the SELU version further down just changes these constants)
    alpha = 1.0
    scale = 1.0
    X_pos = np.maximum(0.0, X)               # linear part for X > 0
    X_neg = np.minimum(X, 0.0)               # exponential part for X <= 0
    X_neg_exp = alpha * (np.exp(X_neg) - 1)
    out = scale * (X_pos + X_neg_exp)
    cache = (scale, alpha, X)
    return out, cache

def elu_bwd(dout, cache):
    scale, alpha, X = cache
    dout = dout * scale
    # Negative branch: d/dx [alpha * (exp(x) - 1)] = alpha * exp(x)
    dX_neg = dout.copy()
    dX_neg[X > 0] = 0
    X_neg = np.minimum(X, 0)
    dX_neg = dX_neg * alpha * np.exp(X_neg)
    # Positive branch: derivative is 1
    dX_pos = dout.copy()
    dX_pos[X < 0] = 0
    dX = dX_neg + dX_pos
    return dX
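
To convince yourself that elu_bwd really is the gradient of elu_fwd, a finite-difference check works well (my own snippet; the step size 1e-5 is an arbitrary choice):

In [ ]:
# Finite-difference gradient check for elu_fwd / elu_bwd (added)
eps = 1e-5
X_gc = np.random.normal(size=(5, 7))
dout_gc = np.random.normal(size=(5, 7))
_, cache_gc = elu_fwd(X_gc)
dX_analytic = elu_bwd(dout_gc, cache_gc)
dX_numeric = np.zeros_like(X_gc)
for i in range(X_gc.shape[0]):
    for j in range(X_gc.shape[1]):
        Xp, Xm = X_gc.copy(), X_gc.copy()
        Xp[i, j] += eps
        Xm[i, j] -= eps
        # numerical derivative of the scalar "loss" sum(dout * elu(X)) w.r.t. X[i, j]
        dX_numeric[i, j] = ((elu_fwd(Xp)[0] * dout_gc).sum() - (elu_fwd(Xm)[0] * dout_gc).sum()) / (2 * eps)
print(np.max(np.abs(dX_analytic - dX_numeric)))  # expected: a very small number (~1e-9 or less)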

In [7]:
# EDIT: same experiment again, this time with a plain ELU followed by standard (inverted) dropout keeping ~95% of the units:
x = np.random.normal(size=(300, 200))
for _ in range(100):
    w = np.random.normal(size=(200, 200), scale=np.sqrt(1/200))  # their initialization scheme
    x = x @ w
    x, _ = elu_fwd(X=x)
    x, _ = dropout_forward(p_dropout=0.95, X=x)
    mean = x.mean(axis=1)
    scale = x.std(axis=1) # standard deviation=square-root(variance)
    print(mean.min(), mean.max(), scale.min(), scale.max())


0.0267840597416 0.34870387555 0.688048836569 1.04717887996
0.0118506955263 0.304876451832 0.56287559608 0.883068359607
-0.00848144962242 0.206902931306 0.461438275839 0.773605957189
-0.0135542733853 0.188593823674 0.382063563973 0.688186843237
-0.0561000469178 0.122238741339 0.349825172164 0.602874984924
-0.0343566384099 0.18028897461 0.31253884688 0.548510582005
-0.0429958073188 0.116054084819 0.258804537482 0.468348512243
-0.0365803448673 0.0872113156711 0.24577948731 0.438054082984
-0.0599072152755 0.107624519345 0.215066658884 0.422357677398
-0.0350411695585 0.0776162947341 0.205136251217 0.425153568286
-0.0283457218838 0.0652535851403 0.185845104457 0.385896334075
-0.0272042989458 0.0766230668132 0.17945437564 0.343158539221
-0.0265363719492 0.0529204998045 0.175905477592 0.337747682314
-0.0195131277272 0.050982043472 0.157488026364 0.320393452692
-0.0279429986826 0.0494597719809 0.148463686617 0.302291903009
-0.0163050932564 0.0395091568666 0.134573867209 0.275776177388
-0.0350636247315 0.0446047271227 0.130681932638 0.260456518769
-0.0280860610856 0.0521212852515 0.127825010424 0.242927405648
-0.031611045363 0.0448226140599 0.119061552083 0.247784549674
-0.0160116482817 0.0350665171291 0.110955975843 0.230000436935
-0.032908915089 0.0455961663955 0.106944810848 0.221535694601
-0.0255515030656 0.0533079790241 0.101583851262 0.218571710404
-0.0287315200628 0.0505785025543 0.100373477838 0.211213550803
-0.0218860512348 0.0395445699361 0.09912119378 0.216332757744
-0.0155101553201 0.0300637152255 0.0940974816116 0.217682172512
-0.0278378861071 0.0382129355701 0.0871000120494 0.213244116615
-0.0166008582174 0.0282495426849 0.0856275861282 0.201773089435
-0.0241697863 0.0411436133239 0.0820385581023 0.206187616881
-0.0157925180858 0.0296287485956 0.0816271592875 0.202022137338
-0.0131223551871 0.0344921217048 0.0810812233719 0.189484668356
-0.0178672551798 0.0196992522479 0.0769284293791 0.190846401789
-0.015907970919 0.0340555629548 0.0739359328161 0.195642685556
-0.0167867545129 0.028162573988 0.0688426757733 0.185973613894
-0.0198006002858 0.0291724924825 0.0722614785159 0.170296669289
-0.0223565323619 0.0271525026234 0.0676686938757 0.165745072873
-0.0129344926851 0.0326925108362 0.0674981892731 0.160179767947
-0.020047396402 0.0244167398387 0.0696481896366 0.154583522557
-0.0354766110678 0.0286733638038 0.0646655802468 0.151386817687
-0.0166541157601 0.019886945775 0.0650172585953 0.15715925585
-0.0169269825833 0.0263314959328 0.0612392050587 0.148117722604
-0.0208108792175 0.0337166211252 0.0567820375857 0.157970273241
-0.0175155554353 0.0194623862136 0.050853477726 0.148709289169
-0.0163136072795 0.0194587848715 0.0528851158249 0.154573862383
-0.0224563452682 0.0330225583336 0.054727681792 0.155185488264
-0.0148554864396 0.0232484025477 0.0526528855982 0.155948090419
-0.015407728749 0.0226526584702 0.0503735663922 0.161205399028
-0.0198421090144 0.0295908023975 0.0545698799751 0.166946872485
-0.0101000184353 0.0164851423265 0.0544275861449 0.17017293538
-0.0183716129586 0.0214069727876 0.0521637561451 0.1592011914
-0.0164162385149 0.0236446085742 0.0512908036539 0.165370371683
-0.0124907074268 0.0173635905341 0.0468553304154 0.155802368363
-0.0264065451783 0.0237436173671 0.0495317528297 0.144526511002
-0.0170799485706 0.0211716388198 0.0497479568515 0.146070335122
-0.0126531275369 0.0214300433465 0.0492852147178 0.144859578639
-0.0096736574645 0.017352726812 0.0518772176472 0.134865122738
-0.0161867091492 0.0147488788412 0.0506882555431 0.132496135044
-0.0214809800781 0.0205734417392 0.0512116767513 0.13908018094
-0.016115397592 0.025904050316 0.046993164924 0.129877520715
-0.0168817140433 0.0189450076508 0.0447395484776 0.127975415041
-0.0126701066329 0.0180751524404 0.0417024671629 0.125152155025
-0.0144218596113 0.0178510686842 0.0409738874511 0.129313206402
-0.0103570746363 0.0142001515875 0.0416981939267 0.129968687474
-0.0123291748387 0.0162648937274 0.0406117482539 0.124972189775
-0.015346694443 0.0252470819536 0.0382869204296 0.133190426772
-0.0140690491473 0.0143299470095 0.0359749590823 0.12737939207
-0.0140148315276 0.0145609375856 0.0360992892364 0.121716602723
-0.0198422635265 0.0202262833543 0.0354968258464 0.130395775452
-0.00848609147717 0.0157402132932 0.0363170194672 0.127288818346
-0.0118144287768 0.0155386830612 0.0386537046769 0.127784844402
-0.0125379898886 0.0172792918855 0.0376219436861 0.127249539625
-0.0148385557007 0.0163772349032 0.0367424337849 0.129384211941
-0.012194784855 0.0190643416517 0.0338767916465 0.1170230661
-0.0150463114264 0.0153197907742 0.0352578702687 0.117922457118
-0.0200746660961 0.0196019318234 0.0340021113441 0.117798748676
-0.00999526481601 0.0168783081618 0.0363771894882 0.114587608924
-0.0107125452314 0.0155062187427 0.0375831283136 0.11133651553
-0.0149699987639 0.0196617516231 0.0361636506983 0.111645390051
-0.0130906496703 0.0206873098424 0.0354054042945 0.113875946051
-0.0157400081638 0.0198031247077 0.0343283491918 0.119985952594
-0.0165416418283 0.0239340061497 0.0321229110038 0.121945496186
-0.0181071892377 0.018167803194 0.0335861073213 0.126017344986
-0.013953680237 0.0172885594547 0.0343894257796 0.125344627025
-0.0186229709036 0.0207122486406 0.0353495796649 0.12250358027
-0.00878233107791 0.0165026205361 0.0380215242646 0.129686818452
-0.0134546068636 0.0146227046556 0.0365660470591 0.129589844885
-0.0111243095259 0.0140677223193 0.0372159083678 0.129993633691
-0.0150002017929 0.0234591629626 0.0388439076028 0.129830324013
-0.012063932092 0.0165148835715 0.0383054163978 0.124212667474
-0.0104740181877 0.0221637545266 0.036589893577 0.12121861639
-0.0117376051638 0.0182647042151 0.0354353912801 0.123536789701
-0.0128660240953 0.0164226391203 0.0356750709574 0.13046883613
-0.0104825693764 0.0208843272362 0.0349852664605 0.122620847611
-0.00977902744701 0.0186812053538 0.0324718208798 0.12684655601
-0.0107864513671 0.0151510064738 0.0350752961548 0.12362225557
-0.0159308759 0.0168429393574 0.0387256721004 0.125421593372
-0.0146330148658 0.0147085934465 0.0390732010775 0.130182556005
-0.0245935641368 0.0273270660344 0.0382575769433 0.130610395861
-0.0128806106158 0.0190988912682 0.038211902248 0.12852023976
-0.0120070545454 0.0149455463847 0.0392919992364 0.135765570327
-0.0104428986723 0.0159663312185 0.0406177137163 0.133933685466

In [13]:
def selu_fwd(X):
    # Same as elu_fwd above, but with the SELU constants alpha and lambda ("scale")
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946
    X_pos = np.maximum(0.0, X)               # linear part for X > 0
    X_neg = np.minimum(X, 0.0)               # exponential part for X <= 0
    X_neg_exp = alpha * (np.exp(X_neg) - 1)
    out = scale * (X_pos + X_neg_exp)
    cache = (scale, alpha, X)
    return out, cache

def selu_bwd(dout, cache):
    scale, alpha, X = cache
    dout = dout * scale
    # Negative branch: d/dx [alpha * (exp(x) - 1)] = alpha * exp(x)
    dX_neg = dout.copy()
    dX_neg[X > 0] = 0
    X_neg = np.minimum(X, 0)
    dX_neg = dX_neg * alpha * np.exp(X_neg)
    # Positive branch: derivative is 1
    dX_pos = dout.copy()
    dX_pos[X < 0] = 0
    dX = dX_neg + dX_pos
    return dX

def dropout_selu_forward(X, keep_prob):
    alpha = -1.7580993408473766   # = -lambda * alpha of the SELU
    fixedPointMean = 0.0
    fixedPointVar = 1.0

    # Keep-mask drawn per unit. Note: unlike the TensorFlow reference (which uses a plain
    # 0/1 mask), the kept units here are additionally rescaled by 1/keep_prob, so the
    # fixed point is only approximately preserved.
    u = np.random.binomial(1, keep_prob, size=X.shape) / keep_prob
    out = X * u + alpha * (1 - u)

    # Affine correction back towards mean 0 / variance 1
    a = np.sqrt(fixedPointVar / (keep_prob * ((1 - keep_prob) * (alpha - fixedPointMean) ** 2 + fixedPointVar)))
    b = fixedPointMean - a * (keep_prob * fixedPointMean + (1 - keep_prob) * alpha)
    out = a * out + b
    cache = a, u
    return out, cache

def dropout_selu_backward(dout, cache):
    a, u = cache
    dX = dout * a * u  # d(out)/dX = a * u
    return dX
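
Quick consistency check (my own snippet): the modular selu_fwd should agree exactly with the earlier np.where-based selu.

In [ ]:
# Check (added): selu_fwd and selu compute the same thing
X_eq = np.random.normal(size=(100, 100))
print(np.max(np.abs(selu_fwd(X_eq)[0] - selu(X_eq))))  # expected: 0.0 (same formula, rearranged)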

In [14]:
# EDIT: same experiment with the modular selu_fwd and dropout_selu_forward (keep_prob=0.95):
x = np.random.normal(size=(300, 200))
for _ in range(100):
    w = np.random.normal(size=(200, 200), scale=np.sqrt(1/200))  # their initialization scheme
    x = x @ w
    x, cache = selu_fwd(x)
    x, _ = dropout_selu_forward(keep_prob=0.95, X=x)
    mean = x.mean(axis=1)
    scale = x.std(axis=1) # standard deviation=square-root(variance)
    print(mean.min(), mean.max(), scale.min(), scale.max())


-0.115228899805 0.273307640366 0.832187931633 1.20092983661
-0.11342776269 0.297433766906 0.890738815289 1.25180513542
-0.12925098648 0.362425379672 0.93295919605 1.3384718603
-0.12964971994 0.366295910839 0.966780089222 1.40596996607
-0.0864575866013 0.290687936815 0.959321646025 1.33767307621
-0.111998627179 0.300401219852 0.951157571282 1.35759226264
-0.0878051364106 0.313732327559 0.985100281202 1.39437359224
-0.102829064867 0.313342590658 0.967987306932 1.41533396008
-0.134024418009 0.514400568221 1.0022174775 1.42117544315
-0.156572603248 0.331289658482 1.00391056547 1.41748889863
-0.164599296478 0.300678860193 0.987115048323 1.39145024166
-0.0764760393873 0.297173187184 0.996668365766 1.37470533887
-0.126296878605 0.326807280604 0.978965605922 1.37800377186
-0.0872685363087 0.326317075429 0.987838245488 1.40211446152
-0.118293785738 0.419532184593 0.958342284191 1.3946407666
-0.0860908759643 0.436037876788 0.961645016304 1.44609169824
-0.0636222554284 0.386101182923 0.998806334399 1.44759920003
-0.181929225149 0.335378265442 0.999062539032 1.50121548239
-0.0895673207494 0.334585065123 0.978050127889 1.44168677893
-0.129702584252 0.410672023252 0.979377827597 1.45735925441
-0.239788123087 0.397520767005 0.945265370174 1.42983578513
-0.114522432621 0.400078782696 0.979690293053 1.47149783274
-0.12286204896 0.29293371492 0.9473377336 1.41184702522
-0.101474229466 0.369619684188 0.999720958148 1.3810675652
-0.0875088160167 0.356033207916 0.996492860525 1.38997988313
-0.0946157989033 0.436864069074 0.970902699465 1.43291050902
-0.0640008116266 0.430058782575 1.00581954935 1.52568786306
-0.151494854851 0.393074369122 0.983681101082 1.46282286595
-0.168454054018 0.289481786372 0.979100203733 1.45568830527
-0.185721044401 0.419578957749 1.00562503129 1.40983235576
-0.0779894210024 0.376220829132 0.986582581946 1.39545711609
-0.0717489161858 0.401662282137 1.00267600521 1.43037733691
-0.0834211478815 0.407552124968 0.976197123755 1.44592490008
-0.124766901476 0.321803063974 0.997203906976 1.3824228762
-0.0726854315395 0.373692391771 0.993444989714 1.39067625825
-0.0889516106906 0.350439422222 0.963004067953 1.44205886941
-0.119535482336 0.377482929881 1.00382149059 1.40373522689
-0.162243524847 0.405572110471 0.972568962118 1.44796505305
-0.115348812343 0.333949912358 0.944730788459 1.42005855386
-0.221421289859 0.39763314378 0.982301036853 1.4116961299
-0.057044608757 0.336705596383 0.986268499542 1.40643411308
-0.123094511023 0.415682596369 0.954767297054 1.35504161193
-0.0759744768879 0.363510539801 0.957621005938 1.37455651238
-0.101645481764 0.354558410155 0.950494082756 1.43273158666
-0.127625310522 0.410104436223 0.993055772671 1.46136881227
-0.0356890630054 0.362543198441 0.978712762247 1.49758591572
-0.0599886813281 0.323979735209 0.973233916807 1.48434763641
-0.124792699341 0.329145019092 0.970251737096 1.41189716712
-0.190615343296 0.429466256457 0.978073987919 1.41863357721
-0.153923626409 0.350039127356 0.986090950284 1.42331658194
-0.16435022628 0.384479650199 1.01247912416 1.46547119279
-0.127714259362 0.350174670831 0.975322522342 1.55295707261
-0.0857134363747 0.410998524197 0.955041865151 1.48983922539
-0.11026119626 0.414133292738 0.96018444356 1.42920560989
-0.11236231239 0.338400084054 1.03160299222 1.40501196495
-0.144120813702 0.388965592842 0.995790919483 1.42095502727
-0.063860892898 0.482335056791 0.997008399852 1.45516344289
-0.124909849825 0.327534844585 1.01094127231 1.47281489121
-0.0597944336601 0.35684589467 1.01157819635 1.41768397797
-0.116424475879 0.370545889082 0.978530051119 1.38158405192
-0.12604576844 0.386974720324 1.01124855551 1.41705015758
-0.125046603261 0.33972502754 1.01096352395 1.37476548722
-0.0635720714009 0.303457290749 1.00467406171 1.37744868269
-0.140896604417 0.351730765134 0.966295356459 1.40981267255
-0.0764839727102 0.364760998051 0.959241988556 1.39752390746
-0.145541537066 0.311771818282 1.00320481672 1.44410827978
-0.079181448015 0.320372685645 1.01196207684 1.39729720958
-0.119191486848 0.307728554891 1.02275844026 1.41402569739
-0.0760905141137 0.345901578886 0.988140631682 1.442634556
-0.170817274991 0.424273709812 1.01878146756 1.46326403069
-0.0479612435572 0.398387422933 0.997689244216 1.49310871523
-0.104832570958 0.370683743693 0.979499043648 1.37027316861
-0.180111593914 0.397108854986 0.958167538423 1.4284760874
-0.102975045489 0.315020834426 0.973338610799 1.37137289664
-0.0621111553628 0.382668641556 0.969211583555 1.43263462232
-0.0880530108861 0.356197030098 0.991121917719 1.44426625451
-0.114691033081 0.354841433295 1.03363527697 1.41606525611
-0.0587464766113 0.443235237822 1.01210956955 1.53606060258
-0.13496297684 0.39412358506 1.01902801234 1.5063958555
-0.0982104874004 0.457533567955 0.979983742039 1.369184593
-0.0709791664788 0.330664605095 0.94361573412 1.42976467017
-0.116435140097 0.397192428061 1.00294462195 1.43266672531
-0.11638472554 0.363163916631 0.953032293025 1.37745969884
-0.103411068685 0.364929661599 1.01612552216 1.38531070911
-0.10960539275 0.399266531849 1.01512582848 1.44068868663
-0.109631165564 0.405557520374 0.944475397744 1.44824900446
-0.123586080421 0.391352404912 0.977718978685 1.40683731901
-0.147916275529 0.38109901041 0.992910014363 1.3948313526
-0.0676123481777 0.371553527035 0.993930322772 1.46712264129
-0.115854620798 0.365832390649 1.00550158714 1.44763375719
-0.0250868837277 0.392841071737 0.989923904154 1.47457954248
-0.174018098769 0.329962278997 1.0131579859 1.44885163387
-0.113631702155 0.39350645338 0.99525406775 1.47566704784
-0.0874180205547 0.343342975009 1.00699174476 1.3995414103
-0.133145950033 0.352128285507 0.975793358161 1.43749520742
-0.156125420521 0.37358808285 0.978818264058 1.42430662335
-0.0350291961047 0.376470230234 0.955409941352 1.41604863003
-0.151570195217 0.386990408432 0.978749655746 1.47324065753
-0.160330897662 0.361891592797 0.963041843429 1.5023204259
-0.1048696337 0.343512187342 0.967033160237 1.40707302618

In [18]:
# EDIT: and once more with selu_fwd followed by the earlier uniform-mask dropout_selu (p_dropout=0.10):
x = np.random.normal(size=(300, 200))
for _ in range(100):
    w = np.random.normal(size=(200, 200), scale=np.sqrt(1/200))  # their initialization scheme
    x = x @ w
    x, cache = selu_fwd(x)
    x = dropout_selu(X=x, p_dropout=0.10)
    mean = x.mean(axis=1)
    scale = x.std(axis=1) # standard deviation=square-root(variance)
    print(mean.min(), mean.max(), scale.min(), scale.max())


-0.235271254088 0.218746820232 0.819007812136 1.11747800096
-0.20951946942 0.180946004293 0.834080921638 1.13113481879
-0.191676850317 0.224930764106 0.868843915156 1.13158722935
-0.176701747105 0.157517304949 0.852585573073 1.12639938655
-0.162840032431 0.186669002989 0.858098346654 1.18314823773
-0.220704901111 0.190378608271 0.833024054349 1.14342284633
-0.19923882051 0.211312399024 0.87131817578 1.121972573
-0.210321754749 0.237660213544 0.871612931827 1.11061683783
-0.185461154997 0.225793126856 0.862512256897 1.1125366835
-0.178912312266 0.155551393681 0.852681937939 1.16124534451
-0.208812346507 0.179387508362 0.84954272907 1.16233918417
-0.227677609213 0.203579291629 0.853329180904 1.13222261062
-0.177117268663 0.182170079468 0.870238280515 1.21209197618
-0.187281710031 0.206016154254 0.866706826485 1.24963262405
-0.18972440741 0.171150605527 0.839788078575 1.16848345882
-0.187734315865 0.204341560836 0.850645039618 1.14832907681
-0.280902837708 0.170151586913 0.888266070211 1.11742795259
-0.220330882606 0.14259808742 0.880024049717 1.15157886524
-0.226636394136 0.23990538548 0.880766720859 1.17482291552
-0.189821889697 0.197320949643 0.843138492515 1.15156231693
-0.205873736051 0.193131118624 0.875224317695 1.15097064375
-0.175492244751 0.238941788248 0.834642009854 1.14497440375
-0.193472005181 0.212400216644 0.831021648919 1.17721384286
-0.246634064447 0.247270559311 0.8547008918 1.16955789493
-0.191233743558 0.144080385902 0.869282115449 1.1375256101
-0.163136827169 0.208774959456 0.825411819344 1.15468653809
-0.218951542539 0.198246038325 0.86545646679 1.13294822955
-0.216843598195 0.1841587331 0.854876706012 1.14539093729
-0.184325132298 0.18782000645 0.877681698758 1.14607753804
-0.199951278206 0.197911458732 0.847665010603 1.22978029564
-0.206654581907 0.195201978252 0.878407308985 1.15406802818
-0.20437491412 0.191322549539 0.858402436328 1.14076194107
-0.20558495464 0.185967871578 0.862151182605 1.15155434123
-0.186210928388 0.227452088474 0.867872081867 1.13965760059
-0.180288625717 0.184311261643 0.87015350107 1.12869123966
-0.232211626681 0.19854605133 0.840691609251 1.16148837842
-0.218985205811 0.221109398767 0.846810926425 1.16182495504
-0.171455130564 0.203968686744 0.876026884182 1.18448132668
-0.179136516363 0.25457484844 0.857672883112 1.15725407675
-0.168851149152 0.196488814157 0.85853365183 1.15100947035
-0.200782681667 0.245520575034 0.815319213692 1.13505933684
-0.196076374254 0.216946720482 0.875178421485 1.14586042992
-0.256058410821 0.252588332087 0.835856111409 1.20354823254
-0.182925761399 0.195474352741 0.842824233653 1.14459453033
-0.214055138057 0.168941626199 0.844726735188 1.13298125855
-0.195325163158 0.224660829554 0.858530846361 1.14165473406
-0.164389995726 0.233696490852 0.824337181215 1.16021691504
-0.168926932499 0.203342206165 0.86365739059 1.12675590037
-0.162744903582 0.225948776182 0.820860826687 1.1540144229
-0.177636087335 0.180090119558 0.868510528365 1.14555843743
-0.185366719749 0.175472965264 0.865415335332 1.15668091424
-0.172144455901 0.198971184807 0.880743440967 1.17422968644
-0.149990556026 0.188368388553 0.868203617649 1.14774137491
-0.184332762657 0.249627322089 0.85248926982 1.1458634033
-0.18322972541 0.228401584125 0.805995120347 1.16455251229
-0.197336426306 0.20840499965 0.883534174274 1.15774321119
-0.201133575704 0.217010114929 0.867981945299 1.13153987447
-0.178305165752 0.177380291659 0.853555202397 1.17495468495
-0.235686944123 0.27563391389 0.872517827639 1.1485206036
-0.238611058438 0.225520759152 0.850135813771 1.1452570092
-0.259778672002 0.217992017853 0.867295338731 1.16102238296
-0.191238396553 0.205245135192 0.841190956359 1.17083074066
-0.199091794564 0.214670972154 0.863521349713 1.16552454106
-0.171686853218 0.165020872897 0.863416790237 1.13921198411
-0.21576394731 0.218156059674 0.85967427098 1.17932862401
-0.211034754928 0.181263462459 0.852238841569 1.15577869761
-0.171734897301 0.279100453093 0.85344826544 1.16221910543
-0.18058173168 0.215356516476 0.812394427257 1.19214139814
-0.22968986685 0.245851465335 0.814961108608 1.13296164351
-0.219671751894 0.215431727004 0.840563229436 1.14072504988
-0.17199595298 0.196332422511 0.820022687474 1.1601672039
-0.16791417979 0.194388700378 0.886054819479 1.18283167625
-0.244062893364 0.228499409558 0.869744300466 1.18297066345
-0.170131325814 0.197957212453 0.871581331834 1.26309722924
-0.221526218891 0.168015691401 0.866273474113 1.16292409454
-0.202304700265 0.268614830259 0.817890519584 1.13329663254
-0.242547048341 0.211104829698 0.836321649289 1.1456543911
-0.210388574893 0.238577685511 0.873284257877 1.15776877689
-0.177084135402 0.199950552718 0.859912283595 1.15676093956
-0.142507001207 0.181483501425 0.850632692514 1.1843000259
-0.226206932601 0.240228493093 0.85312992704 1.14203790583
-0.244211728367 0.230956994942 0.852536135596 1.16515601915
-0.196783355423 0.217203779164 0.835052970249 1.14935186648
-0.173635608596 0.236066523527 0.835443241942 1.17755345428
-0.197964129533 0.173897314788 0.873821626864 1.11564752531
-0.171468212441 0.19839202565 0.840693423865 1.16581235027
-0.164416872794 0.191299332174 0.835008446054 1.15876284314
-0.197619636712 0.231716212086 0.859859459155 1.14281457889
-0.160798227911 0.192000655987 0.851376909838 1.2099392157
-0.195438231889 0.207520012432 0.834400922874 1.22496576616
-0.185863514621 0.23616269701 0.875662296086 1.17353868554
-0.185753292829 0.210278825767 0.882668921186 1.13729341552
-0.172520768407 0.218306559551 0.841147936969 1.16908595573
-0.20651110986 0.163589102113 0.852581600469 1.13359533899
-0.207419326572 0.222878345199 0.868491130856 1.14052416793
-0.173084874526 0.198539726272 0.842677586379 1.14847577783
-0.218996348546 0.17531075816 0.867900683075 1.15273858678
-0.175761437348 0.182870341533 0.849333139767 1.13308471141
-0.150969757134 0.165189388758 0.866035404474 1.1731937917
-0.268348360458 0.230749083185 0.866782049306 1.1533809317

Discussion & wrapup

According to this, even after a 100 layers, mean neuron activations stay fairly close to mean 0 / variance 1 (even the most extreme means/variances are only off by 0.2).

Sepp Hochreiter is amazing: LSTM, meta-learning, SNNN.

I think he has already done a much larger contribution to science than some self-proclaimed pioneers of DL who spend more time on social networks than actually doing any good research.
