Binary classification with logistic regression


In [1]:
import mxnet as mx 
from mxnet import gluon
from tqdm import tqdm

In [2]:
# Sigmoid function
def logistic(z):
    return 1. / (1. + mx.nd.exp(-z))
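
A quick sanity check (an illustrative addition, not part of the original run): the sigmoid maps 0 to 0.5 and squashes large magnitudes toward 0 and 1.

print(logistic(mx.nd.array([-10., 0., 10.])))
# expect values close to [0., 0.5, 1.]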

In [3]:
data_ctx = mx.cpu()
# Change this to `mx.gpu(0)` if you would like to train on an NVIDIA GPU
model_ctx = mx.cpu()

In [4]:
with open("./data/adult/a1a.train") as f:
    train_raw = f.read()

with open("./data/adult/a1a.test") as f:
    test_raw = f.read()

In [5]:
train_raw[:200]


Out[5]:
'-1 5:1 7:1 14:1 19:1 39:1 40:1 51:1 63:1 67:1 73:1 74:1 76:1 78:1 83:1 \n-1 3:1 6:1 17:1 22:1 36:1 41:1 53:1 64:1 67:1 73:1 74:1 76:1 80:1 83:1 \n-1 5:1 6:1 17:1 21:1 35:1 40:1 53:1 63:1 71:1 73:1 74:1 '

In [6]:
test_raw[:200]


Out[6]:
'-1 3:1 11:1 14:1 19:1 39:1 42:1 55:1 64:1 67:1 73:1 75:1 76:1 80:1 83:1 \n-1 3:1 6:1 17:1 27:1 35:1 40:1 57:1 63:1 69:1 73:1 74:1 76:1 81:1 103:1 \n-1 4:1 6:1 15:1 21:1 35:1 40:1 57:1 63:1 67:1 73:1 74:'

In [7]:
def process_data(raw_data):
    # Each line is in LIBSVM sparse format: a label in {-1,1} followed by
    # `index:value` pairs for the nonzero (here always binary) features
    lines = raw_data.splitlines()
    num_examples = len(lines)
    num_features = 123
    X = mx.nd.zeros((num_examples, num_features), ctx=data_ctx)
    Y = mx.nd.zeros((num_examples, 1), ctx=data_ctx)
    for i, line in enumerate(lines):
        tokens = line.split()
        Y[i] = (int(tokens[0]) + 1) / 2  # Map the label from {-1,1} to {0,1}
        for token in tokens[1:]:
            index = int(token.split(':')[0]) - 1  # 1-based index -> 0-based
            X[i, index] = 1
    return X, Y

In [8]:
Xtrain, Ytrain = process_data(train_raw)
Xtest, Ytest = process_data(test_raw)
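
As a guard against parsing mistakes, it is worth checking (an illustrative addition, not part of the original run) that the labels landed in {0, 1} and that the feature matrix is binary:

print(Ytrain.min().asscalar(), Ytrain.max().asscalar())  # expect 0.0 1.0
print(Xtrain.min().asscalar(), Xtrain.max().asscalar())  # expect 0.0 1.0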

In [9]:
Xtrain.shape


Out[9]:
(30956, 123)

In [10]:
Xtest.shape


Out[10]:
(1605, 123)

Instantiate a data loader


In [11]:
# Set the batch size
batch_size = 64

In [12]:
train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(Xtrain, Ytrain),
                                   batch_size=batch_size, shuffle=True)
# No need to shuffle the test set; evaluation is order-independent
test_data = gluon.data.DataLoader(gluon.data.ArrayDataset(Xtest, Ytest),
                                  batch_size=batch_size, shuffle=False)
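
To confirm the loader yields what the model expects (an illustrative check), peek at the first batch:

for data, label in train_data:
    print(data.shape, label.shape)  # expect (64, 123) (64, 1)
    break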

Define a model


In [13]:
net = gluon.nn.Dense(1)

In [14]:
# Initialize the parameters from a N(0, 1) distribution
net.collect_params().initialize(mx.init.Normal(sigma=1.), ctx=model_ctx)
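
Gluon defers shape inference until the first forward pass, so at this point the weight's input dimension is still unknown. Printing the parameter dictionary (an illustrative check) makes that visible:

print(net.collect_params())
# dense0_weight shows an input dimension of 0 (i.e. not yet inferred)
# until net(data) is first called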

In [15]:
# Plain SGD optimizer; trainer.step(batch_size) below rescales the
# summed gradients by 1/batch_size
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01})

In [16]:
# Binary cross-entropy (log loss), summed over the batch
def log_loss(output, y):
    yhat = logistic(output)
    # mx.nd.nansum sums the elements, treating NaNs (e.g. from log(0)) as zero
    return -mx.nd.nansum(y * mx.nd.log(yhat) + (1 - y) * mx.nd.log(1 - yhat))
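
Masking log(0) with nansum hides numerical trouble rather than preventing it. A stabler alternative (a sketch, swapping in Gluon's built-in loss for the hand-rolled one above) fuses the sigmoid and cross-entropy into one log-sum-exp-stabilized operation on the raw logits:

stable_log_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
# inside the training loop: loss = stable_log_loss(net(data), label).sum()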

Train the model


In [17]:
epochs = 200
loss_sequence = []

In [18]:
for e in range(epochs):
    cumulative_loss = 0
    for i, (data, label) in tqdm(enumerate(train_data)):
        data = data.as_in_context(model_ctx)
        label = label.as_in_context(model_ctx)
        with mx.autograd.record():  # record the forward pass for autograd
            output = net(data)
            loss = log_loss(output, label)
        loss.backward()
        # SGD step; passing batch_size rescales the summed gradient by 1/batch_size
        trainer.step(batch_size)
        cumulative_loss += mx.nd.sum(loss).asscalar()
    print("Epoch %s, loss: %s" % (e, cumulative_loss))
    loss_sequence.append(cumulative_loss)


484it [00:03, 130.50it/s]
Epoch 0, loss: 29936.726289749146
484it [00:03, 138.93it/s]
Epoch 1, loss: 23867.01731109619
484it [00:03, 137.62it/s]
Epoch 2, loss: 21486.378536224365
484it [00:03, 139.17it/s]
Epoch 3, loss: 19917.283807754517
484it [00:03, 139.85it/s]
Epoch 4, loss: 18771.711921691895
484it [00:03, 133.60it/s]
Epoch 5, loss: 17872.3711271286
...
484it [00:02, 209.68it/s]
Epoch 196, loss: 10125.126355171204
484it [00:02, 213.00it/s]
Epoch 197, loss: 10124.401735305786
484it [00:02, 213.95it/s]
Epoch 198, loss: 10123.679330825806
484it [00:02, 214.32it/s]
Epoch 199, loss: 10123.809837341309
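
The loop above accumulates loss_sequence but never uses it; a short follow-up cell (an addition, assuming matplotlib is installed) plots the learning curve:

import matplotlib.pyplot as plt

plt.plot(loss_sequence)
plt.xlabel('epoch')
plt.ylabel('cumulative training loss')
plt.show()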

Calculate accuracy


In [19]:
num_correct = 0.0
num_total = len(Xtest)
for i, (data, label) in enumerate(test_data):
    data = data.as_in_context(model_ctx)
    label = label.as_in_context(model_ctx)
    output = net(data)
    # logistic(z) > 0.5 exactly when z > 0, so thresholding the raw output
    # at 0 gives the predicted class: sign yields {-1,1}, and (sign + 1) / 2
    # maps that to {0,1}
    prediction = (mx.nd.sign(output) + 1) / 2
    num_correct += mx.nd.sum(prediction == label)
print("Accuracy: %0.3f (%s/%s)" % (num_correct.asscalar() / num_total,
                                   num_correct.asscalar(), num_total))


Accuracy: 0.840 (1349.0/1605)
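
An equivalent check (an illustrative addition): threshold the predicted probabilities at 0.5 instead of the raw logits at 0. Since logistic(z) > 0.5 exactly when z > 0, this must give the same accuracy, and it exposes the probabilities themselves for downstream use:

probs = logistic(net(Xtest.as_in_context(model_ctx)))  # P(y = 1 | x)
preds = probs > 0.5                                    # elementwise 0/1
print(mx.nd.mean(preds == Ytest.as_in_context(model_ctx)).asscalar())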