Building a prediction API using a pre-trained MXNet model and deploying using SAM

@author: Sunil Mallya

more info: https://github.com/awslabs/mxnet-lambda/

This notebook simplifies the developement enviroment and deployment to AWS Lambda using SAM. You can modify the code, update Lambda function and deploy all using Jupyter. Its Neat!


In [1]:
'''
Reference code to showcase MXNet model prediction on AWS Lambda 
'''

import base64
import os
import boto3
import json
import tempfile
import urllib2 

# Check if Lambda Function
if os.environ.get('LAMBDA_TASK_ROOT') is None:
    print "just exit, we are not in a lambda function",
    import sys; sys.exit(0)

import mxnet as mx
import numpy as np

from PIL import Image
from io import BytesIO
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])

# Gloabls
grids, ground_truth = [], {}

f_params = 'geo/RN101-5k500-0012.params'
f_symbol = 'geo/RN101-5k500-symbol.json'
    
bucket = 'smallya-test'
s3 = boto3.resource('s3')
s3_client = boto3.client('s3')

# load labels
with open('grids.txt', 'r') as f:
    for line in f:
        line = line.strip().split('\t')
        lat = float(line[1])
        lng = float(line[2])
        grids.append((lat, lng))

# Load model
def load_model(s_fname, p_fname):
    """
    Load model checkpoint from file.
    :return: (arg_params, aux_params)
    arg_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's weights.
    aux_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's auxiliary states.
    """
    symbol = mx.symbol.load(s_fname)
    save_dict = mx.nd.load(p_fname)
    arg_params = {}
    aux_params = {}
    for k, v in save_dict.items():
        tp, name = k.split(':', 1)
        if tp == 'arg':
            arg_params[name] = v
        if tp == 'aux':
            aux_params[name] = v
    return symbol, arg_params, aux_params

# load labels
with open('grids.txt', 'r') as f:
    for line in f:
        line = line.strip().split('\t')
        lat = float(line[1])
        lng = float(line[2])
        grids.append((lat, lng))
   
mod = None

with tempfile.NamedTemporaryFile(delete=True) as f_params_file, tempfile.NamedTemporaryFile(delete=True) as f_symbol_file:

    s3_client.download_file(bucket, f_params, f_params_file.name)
    f_params_file.flush()

    s3_client.download_file(bucket, f_symbol, f_symbol_file.name)
    f_symbol_file.flush()

    sym, arg_params, aux_params = load_model(f_symbol_file.name, f_params_file.name)
    mod = mx.mod.Module(symbol=sym)
    mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))])
    mod.set_params(arg_params, aux_params)


### Helpers
def distance(p1, p2):
        R = 6371 # Earth radius in km
        lat1, lng1, lat2, lng2 = map(radians, (p1[0], p1[1], p2[0], p2[1]))
        dlat = lat2 - lat1
        dlng = lng2 - lng1
        a = sin(dlat * 0.5) ** 2 + cos(lat1) * cos(lat2) * (sin(dlng * 0.5) ** 2)
        return 2 * R * asin(sqrt(a))

# mean image for preprocessing
mean_rgb = np.array([123.68, 116.779, 103.939])
mean_rgb = mean_rgb.reshape((3, 1, 1))

def predict(url, dataurl):
    '''
    predict labels for a given image
    '''

    img_file = tempfile.NamedTemporaryFile()
    if url:
        req = urllib2.urlopen(url)
        img_file.write(req.read())
        img_file.flush()
        img = Image.open(img_file.name)
    elif dataurl:
        #convert to image
        img_data = dataurl.split(",")[1]
        if img_data[-2] != "=":
            img_data += "=" # pad it 
        img = Image.open(BytesIO(base64.b64decode(img_data))) 
        img = img.convert('RGB')

    # center crop and no resize
    # ** width, height must be greater than new_width, new_height 
    #new_width, new_height = 224, 224
    #width, height = img.size   # Get dimensions
    #left = (width - new_width)/2
    #top = (height - new_height)/2
    #right = (width + new_width)/2
    #bottom = (height + new_height)/2
    #img = img.crop((left, top, right, bottom))

    # preprocess by cropping to shorter side and then resize
    short_side = min(img.size)
    left = int((img.size[0] - short_side) / 2)
    right = left + short_side
    top = int((img.size[1] - short_side) / 2)
    bottom = top + short_side
    img = img.crop((left, top, right, bottom))
    img = img.resize((224, 224), Image.ANTIALIAS)

    # convert to numpy.ndarray
    sample = np.asarray(img)  
    # swap axes to make image from (224, 224, 3) to (3, 224, 224)
    sample = np.swapaxes(sample, 0, 2)
    sample = np.swapaxes(sample, 1, 2)
    sample = sample[np.newaxis, :] 
    print sample.shape

    # sub mean? 
    normed_img = sample - mean_rgb
    normed_img = normed_img.reshape((1, 3, 224, 224))

    mod.forward(Batch([mx.nd.array(normed_img)]), is_train=False)
    prob = mod.get_outputs()[0].asnumpy()[0]
    pred = np.argsort(prob)[::-1]
    idx = pred[0]
    lat, lng = grids[idx] #top result
    # lat, lng
    return lat, lng

def map_location_to_destination(latlng):
    #TODO: Implement a convinient reverse geolocation API
    return '{"city": "%s", "country": "%s"}' % (loc[0], loc[1]),  

def lambda_handler(event, context):

    url = None 
    data_url = None

    try:
        # API Gateway GET method
        print "Request Method:", event['httpMethod']
        if event['httpMethod'] == 'GET':
            url = event['queryStringParameters']['url']
        #API Gateway POST method
        elif event['httpMethod'] == 'POST':
            data = json.loads(event['body'])
            if data.has_key('dataurl'):
                data_url = data['dataurl']
            else:
                url = data['url']
            
    except KeyError:
        # direct invocation
        url = event['url']

    print "URL:" , url

    lat, lng = predict(url, data_url)

    #latlng = "%s,%s" % (lat,lng)
    #loc = map_location_to_destination(latlng)
    out = {
            "headers": {
                "content-type": "application/json",
                "Access-Control-Allow-Origin": "*"
                },
            "body": '{"Lattitude": "%s", "Longitude": "%s"}' % (lat, lng),  
            "statusCode": 200
          }
    return out


just exit, we are not in a lambda function
An exception has occurred, use %tb to see the full traceback.

SystemExit: 0
/usr/local/lib/python2.7/site-packages/IPython/core/interactiveshell.py:2889: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.
  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)

In [53]:
# Package code and upload 

# refer for more info: http://ipython.readthedocs.io/en/stable/interactive/reference.html?highlight=input%20caching

content = _ih[-2] # Get the input from the previous cell execution 
fname = "lambda_function.py"

with open(fname, 'w') as f:
    f.write(content)

In [54]:
# NOTE: Pre-requisities (AWS CLI) and appropriate config credentials 
code_zip_name = 'mxnet_lambda_code.zip'

# Create a zip file with all of MXNet dependencies (-Fsr : Only sync updated files)
!zip -9r -FSr $code_zip_name * -x *.zip


updating: geolocation_lambda_sam.ipynb (deflated 74%)
updating: swagger.yaml (deflated 78%)
updating: template.yaml (deflated 56%)
  adding: lambda_function.py (deflated 64%)
  adding: template-out.yaml (deflated 60%)

Swagger File


In [55]:
account_id = 'MY_ACC_ID' # <== Substitute your account ID
region = 'us-west-2' # <== Update your region

!sed -e 's/<<region>>/$region/g' swagger.yaml.template > swagger.yaml
!sed -i -e 's/<<account-id>>/$account_id/g' swagger.yaml

Upload Code and YAML files to S3


In [56]:
bucket_loc = "s3://smallya-testw/samtest/" # **NOTE** Make sure bucket is in the same region as region above
!aws s3 cp $code_zip_name $bucket_loc
!aws s3 cp swagger.yaml $bucket_loc


upload: ./mxnet_lambda_code.zip to s3://smallya-testw/samtest/mxnet_lambda_code.zip
upload: ./swagger.yaml to s3://smallya-testw/samtest/swagger.yaml

Template File


In [57]:
definition_url = bucket_loc + 'swagger.yaml' # swagger file location in s3
code_uri = bucket_loc + code_zip_name  # code location in s3
definition_url = definition_url.replace(':', '\:').replace('/', '\/')
code_uri = code_uri.replace(':', '\:').replace('/', '\/')

!sed -e 's/<<def-uri>>/$definition_url/g' template.yaml.template > template.yaml
!sed -i -e 's/<<code-uri>>/$code_uri/g' template.yaml

Deply using SAM


In [58]:
!aws cloudformation package \
 --template-file template.yaml \
 --output-template-file template-out.yaml \
 --s3-bucket $bucket_loc


Successfully packaged artifacts and wrote output template to file template-out.yaml.
Execute the following command to deploy the packaged template
aws cloudformation deploy --template-file /home/ubuntu/workspace/geolocation/lambda-sam/template-out.yaml --stack-name <YOUR STACK NAME>

In [59]:
stack_name = "MX-LAMBDA-TEST"

!aws cloudformation deploy \
--template-file template-out.yaml \
--stack-name $stack_name \
--capabilities CAPABILITY_IAM \
--region $region


Waiting for changeset to be created..
Waiting for stack create/update to complete
Successfully created/updated stack - MX-LAMBDA-TEST

In [1]:
api_endpoint = !aws cloudformation describe-stacks --stack-name $stack_name --region $region | python -c 'import json,sys;obj=json.load(sys.stdin);print obj["Stacks"][0]["Outputs"][0]["OutputValue"];'
print api_endpoint


['usage: aws [options] <command> <subcommand> [<subcommand> ...] [parameters]', 'To see help text, you can run:', '', '  aws help', '  aws <command> help', '  aws <command> <subcommand> help', 'aws: error: argument --region: expected one argument', 'Traceback (most recent call last):', '  File "<string>", line 1, in <module>', '  File "/usr/lib/python2.7/json/__init__.py", line 290, in load', '    **kw)', '  File "/usr/lib/python2.7/json/__init__.py", line 338, in loads', '    return _default_decoder.decode(s)', '  File "/usr/lib/python2.7/json/decoder.py", line 366, in decode', '    obj, end = self.raw_decode(s, idx=_w(s, 0).end())', '  File "/usr/lib/python2.7/json/decoder.py", line 384, in raw_decode', '    raise ValueError("No JSON object could be decoded")', 'ValueError: No JSON object could be decoded']


In [127]:
import requests
import json

img_url = 'https://www.svalbardblues.com/wp-content/uploads/2015/09/Longyearbyen-Svalbard-Spitsbergen-DSB.jpg'
api_endpoint = ['https://udgz5whroh.execute-api.us-east-1.amazonaws.com/prod']
url = api_endpoint[0]+ "/predict?url=" + img_url

# Lets curl and test the endpoint
!curl $url


{"city": "Svalbard", "country": "Svalbard and Jan Mayen"}


In [2]:
import requests
import json

img_url = 'http://www.liveroof.com/wp-content/uploads/2015/01/Javits-Center_Exterior-Dusk_DS1-990x360.jpg'
api_endpoint = ['https://udgz5whroh.execute-api.us-east-1.amazonaws.com/prod']
url = api_endpoint[0]+ "/predict?url=" + img_url

# Lets curl and test the endpoint
!curl $url


{"city": "New York", "country": "United States"}

In [ ]: