Predicting babyweight using BigQuery ML

This notebook illustrates:

  1. Machine Learning using BigQuery
  2. Jupyter Magic for BigQuery in Cloud Datalab

Please see this notebook for more context on this problem and how the features were chosen.


In [1]:
# change these to try this notebook out
PROJECT = 'cloud-training-demos'
REGION = 'us-central1'

In [2]:
import os
os.environ['PROJECT'] = PROJECT
os.environ['REGION'] = REGION

In [3]:
%%bash
gcloud config set project $PROJECT
gcloud config set compute/region $REGION


Updated property [core/project].
Updated property [compute/region].

Exploring the Data

Here, we will be taking natality data and training on features to predict the birth weight.

The CDC's Natality data has details on US births from 1969 to 2008 and is available in BigQuery as a public data set. More details: https://bigquery.cloud.google.com/table/publicdata:samples.natality?tab=details

Let's start by looking at the data after 2000, keeping only rows where the columns we care about have valid values (> 0).


In [4]:
%%bigquery
SELECT
    *
FROM
  publicdata.samples.natality
WHERE
  year > 2000
  AND gestation_weeks > 0
  AND mother_age > 0
  AND plurality > 0
  AND weight_pounds > 0
LIMIT 10


Out[4]:
source_year  year  month  day  wday  state  is_male  child_race  weight_pounds  plurality  apgar_1min  apgar_5min  mother_residence_state  mother_race  mother_age  gestation_weeks  lmp  mother_married  mother_birth_state  cigarette_use  cigarettes_per_day  alcohol_use  drinks_per_week  weight_gain_pounds  born_alive_alive  born_alive_dead  born_dead  ever_born  father_race  father_age  record_weight
200120018 4NYTrue98.31363190001999919910NY2204410012000FalseNY  False 4600012211
200120013 5FLTrue98.249697844041999FL1244199999999TrueForeign  False 9900011261
200120015 4FLTrue94.312241844721999FL1193808992000TrueMA  False 2500011211
200120011 2MOTrue98.3753613333799991999MO1304004052000TrueMO  False 2710021311
2001200112 4ILTrue99.250596513521999IL1284102222001TrueIL  False 4010221281
200120012 4NYFalse96.060507582381999NY1313805172000TrueNY  False 1410021311
200120014 2OHFalse97.35241643771999OH1243907142000TrueKY  False 1310121241
2001200111 2MIFalse96.7505544624419910MI1354202012001TrueMI  False 1520131301
2001200111 5CAFalse96.1310555062219999CA1323802142001TrueCA    9920031221
200120013 1ILFalse910.374954049721999IL1343906162000TrueIL  False 4230141331

(rows: 10, time: 1.9s, 23GB processed, job: job_Jre7EM0iWleBumtVWKqFL6hhPEQ7)

Define Features

Looking over the data set, there are a few columns of interest that could be leveraged into features for a reasonable prediction of approximate birth weight.

Further, some feature engineering may be accomplished with the BigQuery CAST function -- in BQML, all strings are considered categorical features and all numeric types are considered continuous ones.
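
To make the categorical versus continuous distinction concrete, here is a small pure-Python sketch of the two treatments. It is an illustration only and does not show BQML's internal encoding; the function names and the plurality vocabulary are assumptions made for this example.

```python
# Illustrative only: BQML's internal encoding is not shown in this notebook.
def one_hot(value, vocabulary):
    """Return a 0/1 vector with a single 1 at the position of value."""
    return [1.0 if value == v else 0.0 for v in vocabulary]

# Hypothetical vocabulary for plurality after CAST(plurality AS STRING).
PLURALITY_VOCAB = ['1', '2', '3']

def encode_row(mother_age, plurality):
    # mother_age stays numeric, so it is used as one continuous input.
    # plurality is treated as a string, so it becomes one column per category.
    return [float(mother_age)] + one_hot(str(plurality), PLURALITY_VOCAB)

print(encode_row(28, 2))
```

With a numeric column, the model assumes an ordering and a constant effect per unit; with a string column, each distinct value gets its own weight.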

The hashmonth column is added so that we can repeatably split the data without leakage: we want all babies born in the same year and month to be either in the training set or in the test set, not spread between them (otherwise, there would be information leakage when it comes to triplets, etc.).
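
The idea of a repeatable, hash-based split can be sketched in plain Python. This is a minimal sketch under stated assumptions: SHA-256 stands in for FARM_FINGERPRINT (so the hash values differ from BigQuery's), and the function names are hypothetical.

```python
import hashlib

def hash_month(year, month):
    """Deterministic signed 64-bit hash of year and month (stand-in for FARM_FINGERPRINT)."""
    # Mirrors CONCAT(CAST(year AS STRING), CAST(month AS STRING)).
    key = f"{year}{month}".encode("utf-8")
    # Take the first 8 bytes of the SHA-256 digest as a signed integer.
    return int.from_bytes(hashlib.sha256(key).digest()[:8], "big", signed=True)

def is_training(year, month):
    # abs(h) % 4 gives the same value as SQL's ABS(MOD(h, 4)),
    # so roughly three of every four hash buckets go to training.
    return abs(hash_month(year, month)) % 4 < 3

births = [(2001, 8), (2001, 8), (2001, 3), (2002, 12)]
for year, month in births:
    print(year, month, "train" if is_training(year, month) else "eval")
```

Because the bucket depends only on the year and month, rerunning the split always assigns the same records to the same side, and two births in the same month never end up on different sides.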


In [5]:
%%bigquery
SELECT
    weight_pounds, -- this is the label; because it is continuous, we need to use regression
    CAST(is_male AS STRING) AS is_male,
    mother_age,
    CAST(plurality AS STRING) AS plurality,
    gestation_weeks,
    FARM_FINGERPRINT(CONCAT(CAST(YEAR AS STRING), CAST(month AS STRING))) AS hashmonth
FROM
  publicdata.samples.natality
WHERE
  year > 2000
  AND gestation_weeks > 0
  AND mother_age > 0
  AND plurality > 0
  AND weight_pounds > 0
LIMIT 10


Out[5]:
weight_pounds       is_male  mother_age  plurality  gestation_weeks  hashmonth
6.686620406459999   true     18          1          43               8904940584331855459
9.36082764452       true     32          1          41               1088037545023002395
6.9996768185        true     23          1          40               1088037545023002395
9.37405538024       true     34          1          40               1525201076796226340
8.37315671076       true     33          1          40               3408502330831153141
8.437090766739999   false    30          1          39               5896567601480310696
6.1244416383599996  false    24          1          40               6244544205302024223
7.12534030784       false    26          1          41               8029892925374153452
6.944561253         false    31          1          40               2126480030009879160
7.1870697412        false    23          1          40               1403073183891835564

(rows: 10, time: 1.2s, 6GB processed, job: job__zp60pgtV6soL1xBkFRDRng7iIJ3)

Train Model

With the relevant columns chosen, it is now possible to create (train) the model in BigQuery. First, a dataset is needed to store the model. (If this throws an error in Datalab, simply create the dataset from the BigQuery console.)


In [ ]:
%%bash
bq --location=US mk -d demo

With the demo dataset ready, it is possible to create and train a linear regression model.

This will take approximately 4 minutes to run and will show Done when complete.


In [ ]:
%%bigquery
CREATE OR REPLACE MODEL demo.babyweight_model_asis
OPTIONS
  (model_type='linear_reg', input_label_cols=['weight_pounds']) AS
  
WITH natality_data AS (
  SELECT
    weight_pounds, -- this is the label; because it is continuous, we need to use regression
    CAST(is_male AS STRING) AS is_male,
    mother_age,
    CAST(plurality AS STRING) AS plurality,
    gestation_weeks,
    FARM_FINGERPRINT(CONCAT(CAST(YEAR AS STRING), CAST(month AS STRING))) AS hashmonth
  FROM
    publicdata.samples.natality
  WHERE
    year > 2000
    AND gestation_weeks > 0
    AND mother_age > 0
    AND plurality > 0
    AND weight_pounds > 0
)

SELECT
    weight_pounds,
    is_male,
    mother_age,
    plurality,
    gestation_weeks
FROM
    natality_data
WHERE
  ABS(MOD(hashmonth, 4)) < 3  -- select 75% of the data as training

Training Statistics

During the model training (and after the training), it is possible to see the model's training evaluation statistics.

For each training run, a table named <model_name>_eval is created. This table has basic performance statistics for each iteration.

While the new model is training, you can review the training statistics in the BigQuery UI: https://bigquery.cloud.google.com/

Since these statistics are updated after each iteration of model training, you will see different values for each refresh while the model is training.

The training details may also be viewed after the training completes from this notebook.


In [12]:
%%bigquery
SELECT * FROM ML.TRAINING_INFO(MODEL demo.babyweight_model_asis);


Out[12]:
training_run  iteration  loss  eval_loss  duration_ms  learning_rate
051.130821756491.12667393796982050.4
041.132422584211.12847673261262550.8
031.143552633421.1400193056991660.4
021.179054974721.176291379961057570.4
011.572863363181.56866519873981850.4
009.855743484359.86270726649963820.2

(rows: 6, time: 1.3s, 0B processed, job: job_No9S9g6EeX4EdQgdOtcF538SKZ2w)

Some of these columns are self-explanatory, but what do the columns that are specific to BQML mean?

training_run - Will be zero for a newly created model. If the model is re-trained using warm_start, this will increment for each re-training.

iteration - Index of the iteration within the associated training_run, starting with zero for the first iteration.

duration_ms - Indicates how long the iteration took (in ms).

Note: You can also see these stats by refreshing the BigQuery UI window, finding the <model_name> table, selecting it, and then opening the Training Stats sub-header.
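
To see how columns such as iteration, loss, eval_loss and learning_rate arise, here is a generic gradient-descent loop for a one-feature linear regression that records one row per iteration. This is only a sketch: it is not BQML's actual optimizer, and the data points and function names are made up for illustration.

```python
def mse(w, b, data):
    """Mean squared error of the line y = w * x + b over (x, y) pairs."""
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)

def train(train_data, eval_data, learning_rate=0.1, iterations=5):
    w, b = 0.0, 0.0
    history = []
    n = len(train_data)
    for iteration in range(iterations):
        # Gradients of the mean squared error with respect to w and b.
        grad_w = sum(2 * (w * x + b - y) * x for x, y in train_data) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in train_data) / n
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
        # One row per iteration, similar in shape to the training info above.
        history.append({
            "iteration": iteration,
            "loss": mse(w, b, train_data),
            "eval_loss": mse(w, b, eval_data),
            "learning_rate": learning_rate,
        })
    return history

# Made-up (x, y) points, roughly y = 2x.
train_data = [(0.0, 0.1), (0.5, 1.0), (1.0, 2.1)]
eval_data = [(0.25, 0.5), (0.75, 1.5)]
history = train(train_data, eval_data)
for row in history:
    print(row)
```

Here loss is measured on the data used to update the weights, while eval_loss is measured on held-out data that the updates never see.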

Let's plot the training and evaluation loss to see if the model is overfitting.


In [2]:
import google.datalab.bigquery as bq
df = bq.Query("SELECT * FROM ML.TRAINING_INFO(MODEL demo.babyweight_model_asis)").execute().result().to_dataframe()
# plot both lines in same graph
import matplotlib.pyplot as plt
plt.plot( 'iteration', 'loss', data=df, marker='o', color='orange', linewidth=2)
plt.plot( 'iteration', 'eval_loss', data=df, marker='', color='green', linewidth=2, linestyle='dashed')
plt.xlabel('iteration')
plt.ylabel('loss')
plt.legend();



As you can see, the training loss and evaluation loss are essentially identical. We do not seem to be overfitting.
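
The same comparison can be made numerically. The sketch below uses the loss and eval_loss values from Out[12] above, rounded to three decimals; the 5% tolerance and the function names are assumptions chosen for illustration, not an established threshold.

```python
# (iteration, loss, eval_loss), rounded to three decimals from Out[12].
training_info = [
    (0, 9.856, 9.863),
    (1, 1.573, 1.569),
    (2, 1.179, 1.176),
    (3, 1.144, 1.140),
    (4, 1.132, 1.128),
    (5, 1.131, 1.127),
]

def relative_gap(loss, eval_loss):
    """Relative difference between evaluation loss and training loss."""
    return (eval_loss - loss) / loss

def looks_overfit(rows, tolerance=0.05):
    # Hypothetical rule of thumb: flag the model if the final evaluation loss
    # exceeds the final training loss by more than the tolerance.
    _, loss, eval_loss = max(rows)  # row with the highest iteration number
    return relative_gap(loss, eval_loss) > tolerance

print(looks_overfit(training_info))
```

For these values the final evaluation loss is slightly below the training loss, so the check does not flag the model.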

Make a Prediction with BQML using the Model

With a trained model, it is now possible to make predictions. The only difference from the second query above is the reference to the model. The output has been limited (LIMIT 100) to reduce the amount of data returned.

When the ml.PREDICT function is used, the name of the output prediction column is predicted_<label_column_name>.


In [14]:
%%bigquery
SELECT
  *
FROM
  ml.PREDICT(MODEL demo.babyweight_model_asis,
      (SELECT
        weight_pounds,
        CAST(is_male AS STRING) AS is_male,
        mother_age,
        CAST(plurality AS STRING) AS plurality,
        gestation_weeks
      FROM
        publicdata.samples.natality
      WHERE
        year > 2000
        AND gestation_weeks > 0
        AND mother_age > 0
        AND plurality > 0
        AND weight_pounds > 0
    ))
LIMIT 100


Out[14]:
predicted_weight_pounds  weight_pounds  is_male  mother_age  plurality  gestation_weeks
3.92758220737            3.62439958728  true     26          1          25
5.91794435659            4.46877005074  true     24          1          33
5.89704728035            4.2108292042   true     23          1          33
6.40508635577            8.31363190002  true     23          1          35
6.42598343201            7.05920162924  true     24          1          35
6.76359127467            6.062712205    true     28          1          36
6.5546205123             5.74965579296  true     18          1          36
7.05940496486            7.3524164377   true     30          1          37
7.01761081238            6.6248909731   true     28          1          37
6.85043420249            6.7571683303   true     20          1          37
7.05940496486            7.68751907594  true     30          1          37
7.48060111246            7.3744626639   true     38          1          38
7.25073327386            9.81277528162  true     27          1          38
7.14624789267            8.12623897732  true     22          1          38
7.12535081644            8.81187661214  true     21          1          38
7.35521865504            6.91369653632  true     32          1          38
7.22983619762            8.68841774542  true     26          1          38
7.29252742633            7.1870697412   true     29          1          38
7.1044537402             6.4374980504   true     20          1          38
7.14624789267            6.9996768185   true     22          1          38
7.25073327386            6.67118804812  true     27          1          38
7.18804204515            6.6248909731   true     24          1          38
7.2957820492             7.25100379718  true     17          1          39
7.50475281157            7.5618555866   true     27          1          39
7.6928264977             7.87491199864  true     36          1          39

(rows: 100, time: 1.8s, 0B processed, job: job_e39HFO9u7Ccr-sbTPhIGKvYcODKl)
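
Having predicted_weight_pounds next to weight_pounds makes it easy to measure the error. As a rough illustration, the sketch below computes the root mean squared error over the first five rows of Out[14], rounded to two decimals. Five rows are far too few to characterize the model; the snippet only shows the arithmetic.

```python
import math

# (predicted_weight_pounds, weight_pounds) for the first five rows of Out[14],
# rounded to two decimals.
rows = [(3.93, 3.62), (5.92, 4.47), (5.90, 4.21), (6.41, 8.31), (6.43, 7.06)]

def rmse(pairs):
    """Root mean squared error over (predicted, actual) pairs."""
    return math.sqrt(sum((p - a) ** 2 for p, a in pairs) / len(pairs))

print(round(rmse(rows), 2))
```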

More advanced...

In the original example, we were taking into account the idea that if no ultrasound has been performed, some of the features (e.g. is_male) will not be known. Therefore, we augmented the dataset with such masked features and trained a single model to deal with both these scenarios.

In addition, during data exploration, we learned that the data for mothers older than 45 was quite sparse, so we will discretize the mother's age, grouping ages below 18 as LOW and ages above 45 as HIGH.


In [15]:
%%bigquery
SELECT
    weight_pounds,
    CAST(is_male AS STRING) AS is_male,
    IF(mother_age < 18, 'LOW',
         IF(mother_age > 45, 'HIGH',
            CAST(mother_age AS STRING))) AS mother_age,
    CAST(plurality AS STRING) AS plurality,
    CAST(gestation_weeks AS STRING) AS gestation_weeks,
    FARM_FINGERPRINT(CONCAT(CAST(YEAR AS STRING), CAST(month AS STRING))) AS hashmonth
  FROM
    publicdata.samples.natality
  WHERE
    year > 2000
    AND gestation_weeks > 0
    AND mother_age > 0
    AND plurality > 0
    AND weight_pounds > 0
LIMIT 25


Out[15]:
weight_pounds  is_male  mother_age  plurality  gestation_weeks  hashmonth
8.8074673669   true     39          1          42               1088037545023002395
7.3744626639   true     38          1          38               1088037545023002395
8.1350574678   true     20          1          39               6244544205302024223
7.25100379718  true     LOW         1          39               1525201076796226340
6.25671899556  true     29          1          41               7146494315947640619
7.62578964258  true     22          1          41               6392072535155213407
7.3524164377   true     30          1          37               8904940584331855459
9.0940683075   true     25          1          40               8904940584331855459
7.87491199864  true     36          1          39               7170969733900686954
7.5618555866   true     27          1          39               6691862025345277042
8.31363190002  true     23          1          35               2126480030009879160
8.06230492134  true     LOW         1          40               2126480030009879160
4.46877005074  true     24          1          33               2126480030009879160
8.12623897732  true     32          1          40               2126480030009879160
6.062712205    true     28          1          36               7108882242435606404
7.56846945446  true     22          1          46               1403073183891835564
7.12534030784  true     20          1          40               9068386407968572094
7.43839671988  false    31          1          38               6392072535155213407
5.43659938092  false    35          2          34               6392072535155213407
6.18837569434  false    20          1          40               8904940584331855459
8.06230492134  false    25          1          40               8904940584331855459
8.00057548798  false    27          1          40               7170969733900686954
6.0075966395   false    27          1          39               7170969733900686954
7.12534030784  false    34          1          40               7108882242435606404
6.56316153974  false    29          1          39               3408502330831153141

(rows: 25, time: 1.6s, 6GB processed, job: job_0XT7nQDNhs0gRi-S33srgx_5ruwO)
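
The nested IF in the query above can be read as the following Python function. It is only a restatement of the SQL logic for clarity; the function name is hypothetical.

```python
def bucket_mother_age(mother_age):
    """Mirror of the nested IF in the query: LOW, HIGH, or the age as a string."""
    if mother_age < 18:
        return 'LOW'
    if mother_age > 45:
        return 'HIGH'
    # Ages from 18 through 45 each remain their own category.
    return str(mother_age)

for age in (17, 18, 45, 46):
    print(age, bucket_mother_age(age))
```

Note that the boundaries are exclusive: 18 and 45 keep their own categories, and only ages below 18 or above 45 are pooled.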

We will also suppose, on the same dataset, that it is unknown whether the child is male or female, to simulate the case where an ultrasound has not been performed. In this case, plurality is also reduced to Single or Multiple.


In [16]:
%%bigquery
SELECT
    weight_pounds,
    'Unknown' AS is_male,
    IF(mother_age < 18, 'LOW',
         IF(mother_age > 45, 'HIGH',
            CAST(mother_age AS STRING))) AS mother_age,
    IF(plurality > 1, 'Multiple', 'Single') AS plurality,
    CAST(gestation_weeks AS STRING) AS gestation_weeks,
    FARM_FINGERPRINT(CONCAT(CAST(YEAR AS STRING), CAST(month AS STRING))) AS hashmonth
  FROM
    publicdata.samples.natality
  WHERE
    year > 2000
    AND gestation_weeks > 0
    AND mother_age > 0
    AND plurality > 0
    AND weight_pounds > 0
LIMIT 25


Out[16]:
weight_pounds  is_male  mother_age  plurality  gestation_weeks  hashmonth
7.06140625186  Unknown  34          Single     37               1088037545023002395
6.9996768185   Unknown  23          Single     40               1088037545023002395
9.36082764452  Unknown  32          Single     41               1088037545023002395
6.12444163836  Unknown  24          Single     40               6244544205302024223
9.37405538024  Unknown  34          Single     40               1525201076796226340
6.2501051277   Unknown  30          Single     38               8904940584331855459
7.1870697412   Unknown  34          Single     39               8904940584331855459
6.68662040646  Unknown  18          Single     43               8904940584331855459
7.8153871879   Unknown  28          Single     38               7170969733900686954
7.62578964258  Unknown  20          Single     34               6691862025345277042
6.944561253    Unknown  31          Single     40               2126480030009879160
6.6248909731   Unknown  35          Single     40               2126480030009879160
6.9996768185   Unknown  37          Single     40               2126480030009879160
7.50012615324  Unknown  LOW         Single     40               7108882242435606404
7.50012615324  Unknown  30          Single     38               5896567601480310696
8.43709076674  Unknown  30          Single     39               5896567601480310696
7.40532738058  Unknown  22          Single     39               5896567601480310696
7.936641432    Unknown  33          Single     39               1403073183891835564
7.87491199864  Unknown  27          Single     39               1403073183891835564
7.1870697412   Unknown  23          Single     40               1403073183891835564
6.93574276252  Unknown  32          Single     41               1403073183891835564
8.24969784404  Unknown  24          Single     39               9068386407968572094
7.12534030784  Unknown  26          Single     41               8029892925374153452
6.9996768185   Unknown  34          Single     38               3408502330831153141
8.37315671076  Unknown  33          Single     40               3408502330831153141

(rows: 25, time: 1.4s, 6GB processed, job: job_hGtrg2opHeilvzVEsw9mNEtKNCzv)

Bringing these two data sets together, there is now a single dataset in which the child's sex is either known (determined with an ultrasound) or Unknown (without one).


In [17]:
%%bigquery
WITH with_ultrasound AS (
  SELECT
    weight_pounds,
    CAST(is_male AS STRING) AS is_male,
    IF(mother_age < 18, 'LOW',
         IF(mother_age > 45, 'HIGH',
            CAST(mother_age AS STRING))) AS mother_age,
    CAST(plurality AS STRING) AS plurality,
    CAST(gestation_weeks AS STRING) AS gestation_weeks,
    FARM_FINGERPRINT(CONCAT(CAST(YEAR AS STRING), CAST(month AS STRING))) AS hashmonth
  FROM
    publicdata.samples.natality
  WHERE
    year > 2000
    AND gestation_weeks > 0
    AND mother_age > 0
    AND plurality > 0
    AND weight_pounds > 0
),

without_ultrasound AS (
  SELECT
    weight_pounds,
    'Unknown' AS is_male,
    IF(mother_age < 18, 'LOW',
         IF(mother_age > 45, 'HIGH',
            CAST(mother_age AS STRING))) AS mother_age,
    IF(plurality > 1, 'Multiple', 'Single') AS plurality,
    CAST(gestation_weeks AS STRING) AS gestation_weeks,
    FARM_FINGERPRINT(CONCAT(CAST(YEAR AS STRING), CAST(month AS STRING))) AS hashmonth
  FROM
    publicdata.samples.natality
  WHERE
    year > 2000
    AND gestation_weeks > 0
    AND mother_age > 0
    AND plurality > 0
    AND weight_pounds > 0
),

preprocessed AS (
  SELECT * from with_ultrasound
  UNION ALL
  SELECT * from without_ultrasound
)

SELECT
    weight_pounds,
    is_male,
    mother_age,
    plurality,
    gestation_weeks
FROM
    preprocessed
WHERE
  ABS(MOD(hashmonth, 4)) < 3
LIMIT 25


Out[17]:
weight_pounds  is_male  mother_age  plurality  gestation_weeks
4.68702769012  Unknown  30          Multiple   33
7.06361087448  Unknown  32          Single     37
7.5618555866   Unknown  31          Single     37
7.25100379718  Unknown  33          Single     37
5.8312268299   Unknown  27          Single     37
8.75014717878  Unknown  24          Single     38
8.8736060455   Unknown  30          Single     38
6.05389371452  Unknown  LOW         Single     38
7.50012615324  Unknown  23          Single     39
6.93794738514  Unknown  23          Single     39
10.7254890463  Unknown  28          Single     39
8.6200744442   Unknown  31          Single     39
7.89034435698  Unknown  31          Single     39
6.062712205    Unknown  19          Single     39
7.31273323054  Unknown  32          Single     40
7.6279942652   Unknown  30          Single     40
7.7492485093   Unknown  22          Single     40
8.75014717878  Unknown  34          Single     40
7.7492485093   Unknown  30          Single     40
6.2280589015   Unknown  18          Single     40
7.12534030784  Unknown  25          Single     41
6.2501051277   Unknown  28          Single     41
7.7602716224   Unknown  34          Single     43
6.9114919137   Unknown  24          Single     45
5.62399230362  Unknown  29          Single     47

(rows: 25, time: 1.6s, 6GB processed, job: job_QE5oV7jaNvRyUgjJmWGu53cOiywf)
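
The augmentation performed by the query above can be sketched in Python on a single record. This is an illustrative restatement of the SQL, with hypothetical function names; the sample record is taken from the plurality-2 row of Out[15].

```python
def bucket_mother_age(mother_age):
    if mother_age < 18:
        return 'LOW'
    if mother_age > 45:
        return 'HIGH'
    return str(mother_age)

def with_ultrasound(row):
    return {
        'weight_pounds': row['weight_pounds'],
        # BigQuery renders a BOOL cast to STRING as 'true' or 'false'.
        'is_male': str(row['is_male']).lower(),
        'mother_age': bucket_mother_age(row['mother_age']),
        'plurality': str(row['plurality']),
        'gestation_weeks': str(row['gestation_weeks']),
    }

def without_ultrasound(row):
    masked = with_ultrasound(row)
    # Without an ultrasound the sex is unknown and only single vs. multiple is known.
    masked['is_male'] = 'Unknown'
    masked['plurality'] = 'Multiple' if row['plurality'] > 1 else 'Single'
    return masked

def preprocess(rows):
    # Same effect as UNION ALL: every input row appears once in each form.
    return [with_ultrasound(r) for r in rows] + [without_ultrasound(r) for r in rows]

record = {'weight_pounds': 5.43659938092, 'is_male': False,
          'mother_age': 35, 'plurality': 2, 'gestation_weeks': 34}
for out in preprocess([record]):
    print(out)
```

Training on both forms of every record lets a single model serve requests with or without ultrasound information.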

Create a new model

With the data set feature engineered, we are ready to create the model with the CREATE OR REPLACE MODEL statement.

This will take 5-10 minutes and will show Done when complete.


In [18]:
%%bigquery
CREATE OR REPLACE MODEL demo.babyweight_model_fc
OPTIONS
  (model_type='linear_reg', input_label_cols=['weight_pounds']) AS
  
WITH with_ultrasound AS (
  SELECT
    weight_pounds,
    CAST(is_male AS STRING) AS is_male,
    IF(mother_age < 18, 'LOW',
         IF(mother_age > 45, 'HIGH',
            CAST(mother_age AS STRING))) AS mother_age,
    CAST(plurality AS STRING) AS plurality,
    CAST(gestation_weeks AS STRING) AS gestation_weeks,
    FARM_FINGERPRINT(CONCAT(CAST(YEAR AS STRING), CAST(month AS STRING))) AS hashmonth
  FROM
    publicdata.samples.natality
  WHERE
    year > 2000
    AND gestation_weeks > 0
    AND mother_age > 0
    AND plurality > 0
    AND weight_pounds > 0
),

without_ultrasound AS (
  SELECT
    weight_pounds,
    'Unknown' AS is_male,
    IF(mother_age < 18, 'LOW',
         IF(mother_age > 45, 'HIGH',
            CAST(mother_age AS STRING))) AS mother_age,
    IF(plurality > 1, 'Multiple', 'Single') AS plurality,
    CAST(gestation_weeks AS STRING) AS gestation_weeks,
    FARM_FINGERPRINT(CONCAT(CAST(YEAR AS STRING), CAST(month AS STRING))) AS hashmonth
  FROM
    publicdata.samples.natality
  WHERE
    year > 2000
    AND gestation_weeks > 0
    AND mother_age > 0
    AND plurality > 0
    AND weight_pounds > 0
),

preprocessed AS (
  SELECT * from with_ultrasound
  UNION ALL
  SELECT * from without_ultrasound
)

SELECT
    weight_pounds,
    is_male,
    mother_age,
    plurality,
    gestation_weeks
FROM
    preprocessed
WHERE
  ABS(MOD(hashmonth, 4)) < 3


Out[18]:
Done

Training Statistics

While the new model is training, you can review the training statistics in the BigQuery UI: https://bigquery.cloud.google.com/

The training details may also be viewed after the training completes from this notebook.


In [19]:
import google.datalab.bigquery as bq
df = bq.Query("SELECT * FROM ML.TRAINING_INFO(MODEL demo.babyweight_model_fc)").execute().result().to_dataframe()
# plot both lines in same graph
import matplotlib.pyplot as plt
plt.plot( 'iteration', 'loss', data=df, marker='o', color='orange', linewidth=2)
plt.plot( 'iteration', 'eval_loss', data=df, marker='', color='green', linewidth=2, linestyle='dashed')
plt.xlabel('iteration')
plt.ylabel('loss')
plt.legend();


Make a prediction with the new model

Suppose we want to predict a baby's weight given the following factors: the baby is male, the mother is 28 years old, it is a single birth (plurality of 1), and the baby is born after 38 weeks of pregnancy.

To make this prediction, these values are passed as string literals in the inner SELECT statement, since the model was trained on string (categorical) features.


In [20]:
%%bigquery
SELECT
  *
FROM
  ml.PREDICT(MODEL demo.babyweight_model_fc,
      (SELECT
          'True' AS is_male,
          '28' AS mother_age,
          '1' AS plurality,
          '38' AS gestation_weeks
    ))


Out[20]:
predicted_weight_pounds  is_male  mother_age  plurality  gestation_weeks
5.85625668152            True     28          1          38

(rows: 1, time: 1.4s, 0B processed, job: job__WDzbusrFc-NctUS5cpZSaBjWZD-)
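
If predictions are needed for several different inputs, the query text can be assembled programmatically. The helper below is a hypothetical sketch that only builds the SQL string; it does not call BigQuery, and it should only be used with trusted column names and model names because those are inserted without escaping.

```python
def build_predict_query(model, features):
    """Return an ml.PREDICT query that scores one row of string features."""
    columns = ",\n          ".join(
        # Escape single quotes inside values; column names are assumed trusted.
        "'{}' AS {}".format(value.replace("'", "\\'"), name)
        for name, value in features.items()
    )
    return (
        "SELECT\n"
        "  *\n"
        "FROM\n"
        "  ml.PREDICT(MODEL {model},\n"
        "      (SELECT\n"
        "          {columns}\n"
        "    ))"
    ).format(model=model, columns=columns)

query = build_predict_query(
    'demo.babyweight_model_fc',
    {'is_male': 'True', 'mother_age': '28', 'plurality': '1', 'gestation_weeks': '38'},
)
print(query)
```

The resulting string has the same shape as the query in In [20] and could be pasted into a %%bigquery cell.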





Copyright 2018 Google Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License