In [1]:
import mxnet as mx

In [5]:
import os, urllib
def download(url):
    filename = url.split("/")[-1]
    if not os.path.exists(filename):
        urllib.urlretrieve(url, filename)

In [46]:
def get_iterators(batch_size, data_shape=(3, 224, 224)):
    train = mx.io.ImageRecordIter(
        path_imgrec         = '../data/facesqueeze-train.rec',
        data_name           = 'data',
        label_name          = 'softmax_label',
        batch_size          = batch_size,
        data_shape          = data_shape,
        shuffle             = True,
        rand_crop           = True,
        rotate              = 15,
        rand_mirror         = True)
    val = mx.io.ImageRecordIter(
        path_imgrec         = '../data/facesqueeze-val.rec',
        data_name           = 'data',
        label_name          = 'softmax_label',
        batch_size          = batch_size,
        data_shape          = data_shape,
        rand_crop           = False,
        rand_mirror         = False)
    return (train, val)

In [3]:
def get_model(prefix, epoch):
    download(prefix+'-symbol.json')
    download(prefix+'-%04d.params' % (epoch,))

In [19]:
get_model('http://data.mxnet.io/models/imagenet/squeezenet/squeezenet_v1.1', 0)

In [20]:
sym, arg_params, aux_params = mx.model.load_checkpoint('squeezenet_v1.1', 0)

In [47]:
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='flatten0'):
    """
    symbol: the pretrained network symbol
    arg_params: the argument parameters of the pretrained model
    num_classes: the number of classes for the fine-tune datasets
    layer_name: the layer name before the last fully-connected layer
    """
    all_layers = symbol.get_internals()
    net = all_layers[layer_name+'_output']
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc1')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    new_args = dict({k:arg_params[k] for k in arg_params if 'fc1' not in k})
    return (net, new_args)

In [48]:
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

def fit(symbol, arg_params, aux_params, train, val, batch_size, num_gpus, num_epochs=8):
    devs = [mx.gpu(i) for i in range(num_gpus)]
    mod = mx.mod.Module(symbol=symbol, context=devs)
    mod.fit(train, val,
        num_epoch=num_epochs,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback = mx.callback.Speedometer(batch_size, 10),
        kvstore='device',
        optimizer='adam',
        epoch_end_callback=mx.callback.do_checkpoint('models/facesqueeze'),
        optimizer_params={'learning_rate':0.001},
        initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
        eval_metric='acc')
    metric = mx.metric.Accuracy()
    return mod.score(val, metric)

In [49]:
num_classes = 2
batch_per_gpu = 16
num_gpus = 1

(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes, 'flatten')

batch_size = batch_per_gpu * num_gpus
(train, val) = get_iterators(batch_size)
mod_score = fit(new_sym, new_args, aux_params, train, val, batch_size, num_gpus, 1)
assert mod_score > 0.77, "Low training accuracy."


2017-10-11 17:15:24,763 Epoch[0] Batch [10]	Speed: 198.18 samples/sec	accuracy=0.846591
2017-10-11 17:15:25,464 Epoch[0] Batch [20]	Speed: 228.67 samples/sec	accuracy=0.962500
2017-10-11 17:15:26,170 Epoch[0] Batch [30]	Speed: 227.06 samples/sec	accuracy=0.968750
2017-10-11 17:15:26,873 Epoch[0] Batch [40]	Speed: 228.06 samples/sec	accuracy=0.943750
2017-10-11 17:15:27,579 Epoch[0] Batch [50]	Speed: 227.16 samples/sec	accuracy=0.893750
2017-10-11 17:15:28,286 Epoch[0] Batch [60]	Speed: 226.66 samples/sec	accuracy=0.925000
2017-10-11 17:15:28,994 Epoch[0] Batch [70]	Speed: 227.98 samples/sec	accuracy=0.931250
2017-10-11 17:15:30,948 Epoch[0] Batch [80]	Speed: 82.04 samples/sec	accuracy=0.937500
2017-10-11 17:15:31,649 Epoch[0] Batch [90]	Speed: 229.26 samples/sec	accuracy=0.956250
2017-10-11 17:15:32,350 Epoch[0] Batch [100]	Speed: 228.66 samples/sec	accuracy=0.931250
2017-10-11 17:15:33,051 Epoch[0] Batch [110]	Speed: 229.03 samples/sec	accuracy=0.925000
2017-10-11 17:15:33,752 Epoch[0] Batch [120]	Speed: 228.59 samples/sec	accuracy=0.956250
2017-10-11 17:15:34,455 Epoch[0] Batch [130]	Speed: 228.86 samples/sec	accuracy=0.981250
2017-10-11 17:15:35,162 Epoch[0] Batch [140]	Speed: 227.58 samples/sec	accuracy=0.937500
2017-10-11 17:15:37,084 Epoch[0] Batch [150]	Speed: 83.42 samples/sec	accuracy=0.912500
2017-10-11 17:15:37,786 Epoch[0] Batch [160]	Speed: 229.11 samples/sec	accuracy=0.968750
2017-10-11 17:15:38,491 Epoch[0] Batch [170]	Speed: 227.31 samples/sec	accuracy=0.931250
2017-10-11 17:15:39,195 Epoch[0] Batch [180]	Speed: 227.66 samples/sec	accuracy=0.950000
2017-10-11 17:15:39,897 Epoch[0] Batch [190]	Speed: 228.40 samples/sec	accuracy=0.937500
2017-10-11 17:15:40,598 Epoch[0] Batch [200]	Speed: 229.32 samples/sec	accuracy=0.925000
2017-10-11 17:15:41,304 Epoch[0] Batch [210]	Speed: 227.07 samples/sec	accuracy=0.906250
2017-10-11 17:15:43,026 Epoch[0] Batch [220]	Speed: 93.14 samples/sec	accuracy=0.931250
2017-10-11 17:15:43,723 Epoch[0] Batch [230]	Speed: 230.07 samples/sec	accuracy=0.950000
2017-10-11 17:15:44,427 Epoch[0] Batch [240]	Speed: 227.75 samples/sec	accuracy=0.925000
2017-10-11 17:15:45,129 Epoch[0] Batch [250]	Speed: 229.39 samples/sec	accuracy=0.943750
2017-10-11 17:15:45,830 Epoch[0] Batch [260]	Speed: 228.57 samples/sec	accuracy=0.956250
2017-10-11 17:15:46,533 Epoch[0] Batch [270]	Speed: 228.29 samples/sec	accuracy=0.956250
2017-10-11 17:15:47,241 Epoch[0] Batch [280]	Speed: 227.25 samples/sec	accuracy=0.937500
2017-10-11 17:15:47,950 Epoch[0] Batch [290]	Speed: 227.09 samples/sec	accuracy=0.906250
2017-10-11 17:15:49,801 Epoch[0] Batch [300]	Speed: 86.62 samples/sec	accuracy=0.925000
2017-10-11 17:15:50,506 Epoch[0] Batch [310]	Speed: 227.76 samples/sec	accuracy=0.912500
2017-10-11 17:15:51,208 Epoch[0] Batch [320]	Speed: 228.62 samples/sec	accuracy=0.918750
2017-10-11 17:15:51,909 Epoch[0] Batch [330]	Speed: 228.55 samples/sec	accuracy=0.900000
2017-10-11 17:15:52,617 Epoch[0] Batch [340]	Speed: 226.52 samples/sec	accuracy=0.943750
2017-10-11 17:15:53,328 Epoch[0] Batch [350]	Speed: 226.11 samples/sec	accuracy=0.950000
2017-10-11 17:15:54,034 Epoch[0] Batch [360]	Speed: 229.44 samples/sec	accuracy=0.993750
2017-10-11 17:15:55,741 Epoch[0] Batch [370]	Speed: 93.82 samples/sec	accuracy=0.962500
2017-10-11 17:15:56,444 Epoch[0] Batch [380]	Speed: 227.91 samples/sec	accuracy=0.975000
2017-10-11 17:15:57,146 Epoch[0] Batch [390]	Speed: 229.43 samples/sec	accuracy=0.987500
2017-10-11 17:15:57,851 Epoch[0] Batch [400]	Speed: 228.17 samples/sec	accuracy=0.975000
2017-10-11 17:15:58,557 Epoch[0] Batch [410]	Speed: 227.18 samples/sec	accuracy=0.987500
2017-10-11 17:15:59,261 Epoch[0] Batch [420]	Speed: 228.35 samples/sec	accuracy=0.950000
2017-10-11 17:15:59,969 Epoch[0] Batch [430]	Speed: 226.22 samples/sec	accuracy=0.975000
2017-10-11 17:16:01,832 Epoch[0] Batch [440]	Speed: 85.96 samples/sec	accuracy=0.975000
2017-10-11 17:16:02,534 Epoch[0] Batch [450]	Speed: 229.36 samples/sec	accuracy=0.981250
2017-10-11 17:16:03,240 Epoch[0] Batch [460]	Speed: 226.98 samples/sec	accuracy=0.987500
2017-10-11 17:16:03,946 Epoch[0] Batch [470]	Speed: 226.98 samples/sec	accuracy=0.981250
2017-10-11 17:16:04,653 Epoch[0] Batch [480]	Speed: 227.04 samples/sec	accuracy=0.975000
2017-10-11 17:16:05,357 Epoch[0] Batch [490]	Speed: 228.35 samples/sec	accuracy=0.981250
2017-10-11 17:16:06,069 Epoch[0] Batch [500]	Speed: 225.13 samples/sec	accuracy=0.981250
2017-10-11 17:16:07,913 Epoch[0] Batch [510]	Speed: 86.98 samples/sec	accuracy=0.993750
2017-10-11 17:16:08,577 Epoch[0] Batch [520]	Speed: 241.25 samples/sec	accuracy=0.993750
2017-10-11 17:16:09,286 Epoch[0] Batch [530]	Speed: 226.27 samples/sec	accuracy=0.993750
2017-10-11 17:16:09,990 Epoch[0] Batch [540]	Speed: 227.49 samples/sec	accuracy=0.993750
2017-10-11 17:16:10,694 Epoch[0] Batch [550]	Speed: 227.79 samples/sec	accuracy=0.968750
2017-10-11 17:16:11,400 Epoch[0] Batch [560]	Speed: 226.99 samples/sec	accuracy=0.993750
2017-10-11 17:16:12,108 Epoch[0] Batch [570]	Speed: 226.43 samples/sec	accuracy=0.975000
2017-10-11 17:16:12,822 Epoch[0] Batch [580]	Speed: 224.57 samples/sec	accuracy=0.975000
2017-10-11 17:16:14,662 Epoch[0] Batch [590]	Speed: 87.03 samples/sec	accuracy=0.968750
2017-10-11 17:16:15,370 Epoch[0] Batch [600]	Speed: 227.45 samples/sec	accuracy=0.993750
2017-10-11 17:16:16,081 Epoch[0] Batch [610]	Speed: 225.54 samples/sec	accuracy=0.993750
2017-10-11 17:16:16,785 Epoch[0] Batch [620]	Speed: 227.45 samples/sec	accuracy=1.000000
2017-10-11 17:16:17,489 Epoch[0] Batch [630]	Speed: 227.76 samples/sec	accuracy=0.968750
2017-10-11 17:16:18,198 Epoch[0] Batch [640]	Speed: 226.17 samples/sec	accuracy=1.000000
2017-10-11 17:16:18,913 Epoch[0] Batch [650]	Speed: 225.08 samples/sec	accuracy=0.956250
2017-10-11 17:16:20,723 Epoch[0] Batch [660]	Speed: 88.55 samples/sec	accuracy=0.956250
2017-10-11 17:16:21,426 Epoch[0] Batch [670]	Speed: 228.66 samples/sec	accuracy=0.981250
2017-10-11 17:16:22,135 Epoch[0] Batch [680]	Speed: 225.98 samples/sec	accuracy=0.975000
2017-10-11 17:16:22,841 Epoch[0] Batch [690]	Speed: 227.08 samples/sec	accuracy=0.993750
2017-10-11 17:16:23,551 Epoch[0] Batch [700]	Speed: 225.94 samples/sec	accuracy=1.000000
2017-10-11 17:16:24,261 Epoch[0] Batch [710]	Speed: 225.82 samples/sec	accuracy=0.987500
2017-10-11 17:16:24,976 Epoch[0] Batch [720]	Speed: 224.31 samples/sec	accuracy=0.981250
2017-10-11 17:16:26,751 Epoch[0] Batch [730]	Speed: 90.23 samples/sec	accuracy=0.981250
2017-10-11 17:16:27,455 Epoch[0] Batch [740]	Speed: 227.49 samples/sec	accuracy=1.000000
2017-10-11 17:16:28,162 Epoch[0] Batch [750]	Speed: 226.79 samples/sec	accuracy=0.975000
2017-10-11 17:16:28,871 Epoch[0] Batch [760]	Speed: 226.18 samples/sec	accuracy=0.987500
2017-10-11 17:16:29,579 Epoch[0] Batch [770]	Speed: 226.47 samples/sec	accuracy=0.987500
2017-10-11 17:16:30,290 Epoch[0] Batch [780]	Speed: 225.44 samples/sec	accuracy=1.000000
2017-10-11 17:16:31,005 Epoch[0] Batch [790]	Speed: 225.06 samples/sec	accuracy=1.000000
2017-10-11 17:16:31,720 Epoch[0] Batch [800]	Speed: 225.02 samples/sec	accuracy=0.993750
2017-10-11 17:16:32,605 Epoch[0] Batch [810]	Speed: 181.07 samples/sec	accuracy=1.000000
2017-10-11 17:16:33,312 Epoch[0] Batch [820]	Speed: 226.72 samples/sec	accuracy=0.993750
2017-10-11 17:16:34,023 Epoch[0] Batch [830]	Speed: 225.48 samples/sec	accuracy=0.993750
2017-10-11 17:16:34,734 Epoch[0] Batch [840]	Speed: 226.29 samples/sec	accuracy=1.000000
2017-10-11 17:16:35,851 Epoch[0] Train-accuracy=1.000000
2017-10-11 17:16:35,852 Epoch[0] Time cost=71.978
2017-10-11 17:16:35,917 Saved checkpoint to "models/facesqueeze-0001.params"
2017-10-11 17:16:38,736 Epoch[0] Validation-accuracy=0.984043

In [50]:
model_loaded = mx.mod.Module.load('models/facesqueeze', 1)

In [51]:
print model_loaded


<mxnet.module.module.Module object at 0x7f111c443ed0>

In [53]:
mx.viz.plot_network(new_sym)


Out[53]:
plot data data conv1 Convolution 3x3/2x2, 64 conv1->data relu_conv1 Activation relu relu_conv1->conv1 pool1 Pooling max, 3x3/2x2 pool1->relu_conv1 fire2_squeeze1x1 Convolution 1x1/1x1, 16 fire2_squeeze1x1->pool1 fire2_relu_squeeze1x1 Activation relu fire2_relu_squeeze1x1->fire2_squeeze1x1 fire2_expand1x1 Convolution 1x1/1x1, 64 fire2_expand1x1->fire2_relu_squeeze1x1 fire2_relu_expand1x1 Activation relu fire2_relu_expand1x1->fire2_expand1x1 fire2_expand3x3 Convolution 3x3/1x1, 64 fire2_expand3x3->fire2_relu_squeeze1x1 fire2_relu_expand3x3 Activation relu fire2_relu_expand3x3->fire2_expand3x3 fire2_concat fire2_concat fire2_concat->fire2_relu_expand1x1 fire2_concat->fire2_relu_expand3x3 fire3_squeeze1x1 Convolution 1x1/1x1, 16 fire3_squeeze1x1->fire2_concat fire3_relu_squeeze1x1 Activation relu fire3_relu_squeeze1x1->fire3_squeeze1x1 fire3_expand1x1 Convolution 1x1/1x1, 64 fire3_expand1x1->fire3_relu_squeeze1x1 fire3_relu_expand1x1 Activation relu fire3_relu_expand1x1->fire3_expand1x1 fire3_expand3x3 Convolution 3x3/1x1, 64 fire3_expand3x3->fire3_relu_squeeze1x1 fire3_relu_expand3x3 Activation relu fire3_relu_expand3x3->fire3_expand3x3 fire3_concat fire3_concat fire3_concat->fire3_relu_expand1x1 fire3_concat->fire3_relu_expand3x3 pool3 Pooling max, 3x3/2x2 pool3->fire3_concat fire4_squeeze1x1 Convolution 1x1/1x1, 32 fire4_squeeze1x1->pool3 fire4_relu_squeeze1x1 Activation relu fire4_relu_squeeze1x1->fire4_squeeze1x1 fire4_expand1x1 Convolution 1x1/1x1, 128 fire4_expand1x1->fire4_relu_squeeze1x1 fire4_relu_expand1x1 Activation relu fire4_relu_expand1x1->fire4_expand1x1 fire4_expand3x3 Convolution 3x3/1x1, 128 fire4_expand3x3->fire4_relu_squeeze1x1 fire4_relu_expand3x3 Activation relu fire4_relu_expand3x3->fire4_expand3x3 fire4_concat fire4_concat fire4_concat->fire4_relu_expand1x1 fire4_concat->fire4_relu_expand3x3 fire5_squeeze1x1 Convolution 1x1/1x1, 32 fire5_squeeze1x1->fire4_concat fire5_relu_squeeze1x1 Activation relu fire5_relu_squeeze1x1->fire5_squeeze1x1 fire5_expand1x1 Convolution 1x1/1x1, 128 fire5_expand1x1->fire5_relu_squeeze1x1 fire5_relu_expand1x1 Activation relu fire5_relu_expand1x1->fire5_expand1x1 fire5_expand3x3 Convolution 3x3/1x1, 128 fire5_expand3x3->fire5_relu_squeeze1x1 fire5_relu_expand3x3 Activation relu fire5_relu_expand3x3->fire5_expand3x3 fire5_concat fire5_concat fire5_concat->fire5_relu_expand1x1 fire5_concat->fire5_relu_expand3x3 pool5 Pooling max, 3x3/2x2 pool5->fire5_concat fire6_squeeze1x1 Convolution 1x1/1x1, 48 fire6_squeeze1x1->pool5 fire6_relu_squeeze1x1 Activation relu fire6_relu_squeeze1x1->fire6_squeeze1x1 fire6_expand1x1 Convolution 1x1/1x1, 192 fire6_expand1x1->fire6_relu_squeeze1x1 fire6_relu_expand1x1 Activation relu fire6_relu_expand1x1->fire6_expand1x1 fire6_expand3x3 Convolution 3x3/1x1, 192 fire6_expand3x3->fire6_relu_squeeze1x1 fire6_relu_expand3x3 Activation relu fire6_relu_expand3x3->fire6_expand3x3 fire6_concat fire6_concat fire6_concat->fire6_relu_expand1x1 fire6_concat->fire6_relu_expand3x3 fire7_squeeze1x1 Convolution 1x1/1x1, 48 fire7_squeeze1x1->fire6_concat fire7_relu_squeeze1x1 Activation relu fire7_relu_squeeze1x1->fire7_squeeze1x1 fire7_expand1x1 Convolution 1x1/1x1, 192 fire7_expand1x1->fire7_relu_squeeze1x1 fire7_relu_expand1x1 Activation relu fire7_relu_expand1x1->fire7_expand1x1 fire7_expand3x3 Convolution 3x3/1x1, 192 fire7_expand3x3->fire7_relu_squeeze1x1 fire7_relu_expand3x3 Activation relu fire7_relu_expand3x3->fire7_expand3x3 fire7_concat fire7_concat fire7_concat->fire7_relu_expand1x1 fire7_concat->fire7_relu_expand3x3 fire8_squeeze1x1 Convolution 1x1/1x1, 64 fire8_squeeze1x1->fire7_concat fire8_relu_squeeze1x1 Activation relu fire8_relu_squeeze1x1->fire8_squeeze1x1 fire8_expand1x1 Convolution 1x1/1x1, 256 fire8_expand1x1->fire8_relu_squeeze1x1 fire8_relu_expand1x1 Activation relu fire8_relu_expand1x1->fire8_expand1x1 fire8_expand3x3 Convolution 3x3/1x1, 256 fire8_expand3x3->fire8_relu_squeeze1x1 fire8_relu_expand3x3 Activation relu fire8_relu_expand3x3->fire8_expand3x3 fire8_concat fire8_concat fire8_concat->fire8_relu_expand1x1 fire8_concat->fire8_relu_expand3x3 fire9_squeeze1x1 Convolution 1x1/1x1, 64 fire9_squeeze1x1->fire8_concat fire9_relu_squeeze1x1 Activation relu fire9_relu_squeeze1x1->fire9_squeeze1x1 fire9_expand1x1 Convolution 1x1/1x1, 256 fire9_expand1x1->fire9_relu_squeeze1x1 fire9_relu_expand1x1 Activation relu fire9_relu_expand1x1->fire9_expand1x1 fire9_expand3x3 Convolution 3x3/1x1, 256 fire9_expand3x3->fire9_relu_squeeze1x1 fire9_relu_expand3x3 Activation relu fire9_relu_expand3x3->fire9_expand3x3 fire9_concat fire9_concat fire9_concat->fire9_relu_expand1x1 fire9_concat->fire9_relu_expand3x3 drop9 drop9 drop9->fire9_concat conv10 Convolution 1x1/1x1, 1000 conv10->drop9 relu_conv10 Activation relu relu_conv10->conv10 pool10 Pooling avg, 1x1/1x1 pool10->relu_conv10 flatten flatten flatten->pool10 fc1 FullyConnected 2 fc1->flatten softmax_label softmax_label softmax softmax softmax->fc1 softmax->softmax_label

Evaluate

Ok. So we have 98.40% accuracy on our validation set. Now let's see how our model performs with images pulled from the internet.


In [2]:
import mxnet as mx

model_loaded = mx.mod.Module.load('models/facesqueeze', 1)
model_loaded.bind(for_training=False, data_shapes= [('data', (1,3,224,224))])


/Users/nickrobi/.virtualenvs/reinvent/lib/python2.7/site-packages/mxnet/module/base_module.py:65: UserWarning: Data provided by label_shapes don't match names specified by label_names ([] vs. ['softmax_label'])
  warnings.warn(msg)

In [4]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import time

from collections import namedtuple
Batch = namedtuple('Batch', ['data'])

reshape=(224, 224)


/Users/nickrobi/.virtualenvs/reinvent/lib/python2.7/site-packages/matplotlib/font_manager.py:279: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.
  'Matplotlib is building the font cache using fc-list. '

In [5]:
filename = '../examples/george_bush.jpg'

img = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2RGB)

In [6]:
plt.imshow(img)


Out[6]:
<matplotlib.image.AxesImage at 0x112aae9d0>

In [7]:
# Run forward on the image
# Resize image to fit network input
start_time = time.time()
img = cv2.resize(img, reshape)
img = np.swapaxes(img, 0, 2)
img = np.swapaxes(img, 1, 2)
img = img[np.newaxis, :]



model_loaded.forward(Batch([mx.nd.array(img)]))
prob = model_loaded.get_outputs()[0].asnumpy()
prob = np.squeeze(prob)
print("--- %s seconds ---" % (time.time() - start_time))


--- 0.0463259220123 seconds ---

In [9]:
prob


Out[9]:
array([  9.99988794e-01,   1.11481868e-05], dtype=float32)

In [10]:
filename = '../examples/condoleezza_rice.jpg'
img = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2RGB)

In [11]:
plt.imshow(img)


Out[11]:
<matplotlib.image.AxesImage at 0x11589ba10>

In [12]:
# Run forward on the image
# Resize image to fit network input
start_time = time.time()
img = cv2.resize(img, reshape)
img = np.swapaxes(img, 0, 2)
img = np.swapaxes(img, 1, 2)
img = img[np.newaxis, :]



model_loaded.forward(Batch([mx.nd.array(img)]))
prob = model_loaded.get_outputs()[0].asnumpy()
prob = np.squeeze(prob)
print("--- %s seconds ---" % (time.time() - start_time))


--- 0.0598978996277 seconds ---

In [13]:
prob


Out[13]:
array([  1.00000000e+00,   2.21342259e-12], dtype=float32)

In [ ]: