Understanding Electronic Health Records with BigQuery ML

This tutorial introduces BigQuery ML (BQML) in the context of working with the MIMIC3 dataset.

BigQuery ML adds only a few statements to standard SQL. These statements automate the creation and evaluation of statistical models on BigQuery datasets. BigQuery ML has several advantages over older machine learning tools and workflows: high performance on massive datasets, support for HIPAA compliance, and ease of use. BQML automatically applies state-of-the-art machine learning best practices to your dataset.

MIMIC3 is a 10-year database of health records from the intensive care unit of Beth Israel Deaconess Medical Center in Boston. It's full of insights that are just begging to be uncovered.

Table of Contents

Setup

Covers importing libraries, and authenticating with Google Cloud in Colab.

Case complexity & mortality

Non-technical. Introduces the theme for this tutorial.

Taking a first look at the data

Covers basic SQL syntax, how BigQuery integrates with Colab and pandas, and the basics of creating visualizations with seaborn.

Creating a classification model

Covers creating and training simple models with BigQuery ML.

Plotting the predictions

Covers inference (making predictions) with BigQuery ML models, and how to inspect the weights of a parametric model.

Adding a confounding variable

Covers creating and training a slightly more complicated model, and introduces how BigQuery ML's model comparison features can be used to address confounding relationships.

Plotting ROC and precision-recall curves

Covers how to create ROC and precision-recall curves with BigQuery ML. These are visualizations that describe the performance of binary classification models.

More complex models

Creating the models

Covers creating logistic regression models with many input variables.

Getting evaluation metrics

Covers how to get numerical measures of model performance using BigQuery ML.

Exploring our model

Demonstrates how to interpret models with many variables.

Conclusion

Non-technical. Looks back on how we have used BigQuery ML to answer a research question.

Setup

First, you'll need to sign in to your Google account to access the Google Cloud Platform (GCP).

We're also going to import some standard Python data analysis packages that we'll use later to visualize our models.


In [0]:
from __future__ import print_function
from google.colab import auth
from google.cloud import bigquery
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [0]:
auth.authenticate_user()

Next you'll need to enter some information on how to access the data.

analysis_project is the project used for processing the queries.

The other fields, admissions_table, d_icd_diagnoses_table, diagnoses_icd_table, and patients_table, identify the BigQuery tables we're going to query. They're written in the form "project_id.dataset_id.table_id". We're going to use a slightly modified version of the %%bigquery cell magic in this tutorial, which replaces these variables with their values wherever they appear surrounded by curly braces.
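
For example, once you've run the setup cell below, a query cell like this sketch

%%bigquery
SELECT COUNT(*) FROM `{admissions_table}`

is rewritten to count the rows of `physionet-data.mimiciii_clinical.admissions` (or whatever table you entered in the form) before the query is sent to BigQuery.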


In [0]:
#@title Fill out this form then press [shift ⇧]+[enter ⏎] {run: "auto"}
import subprocess
import re

analysis_project = 'your-analysis-project'  #@param {type:"string"}

admissions_table = 'physionet-data.mimiciii_clinical.admissions'  # @param {type: "string"}
d_icd_diagnoses_table = 'physionet-data.mimiciii_clinical.d_icd_diagnoses'  # @param {type: "string"}
diagnoses_icd_table = 'physionet-data.mimiciii_clinical.diagnoses_icd'  # @param {type: "string"}
patients_table = 'physionet-data.mimiciii_clinical.patients'  # @param {type: "string"}

# Preprocess queries made with the %%bigquery magic
# by substituting these values
sub_dict = {
    'admissions_table': admissions_table,
    'd_icd_diagnoses_table': d_icd_diagnoses_table,
    'diagnoses_icd_table': diagnoses_icd_table,
    'patients_table': patients_table
}

# Get a suffix to attach to the names of the models created during this tutorial
# to avoid collisions between simultaneous users.
account = subprocess.check_output(
    ['gcloud', 'config', 'list', 'account', '--format',
     'value(core.account)']).decode().strip()
sub_dict['suffix'] = re.sub(r'[^\w]', '_', account)[:900]

# Set the default project for running queries
bigquery.magics.context.project = analysis_project

# Set up the substitution preprocessing injection
if bigquery.magics._run_query.__name__ != 'format_and_run_query':
  original_run_query = bigquery.magics._run_query

def format_and_run_query(client, query, job_config=None):
  query = query.format(**sub_dict)
  return original_run_query(client, query, job_config)

bigquery.magics._run_query = format_and_run_query

print('analysis_project:', analysis_project)
print()
print('custom %%bigquery magic substitutions:')
for k, v in sub_dict.items():
  print(' ', '{%s}' % k, '→', v)

In [0]:
%config InlineBackend.figure_format = 'svg'

In [0]:
bq = bigquery.Client(project=analysis_project)

Case complexity & mortality

This tutorial is a case study. We're going to use BQML and MIMIC3 to answer a research question.

In the intensive care unit, are complex cases more or less likely to be fatal?

Maybe it's obvious that they would be more fatal. After all, things only get worse as you add more comorbidities. Or maybe the exact opposite is true. Compare the patient who comes to the ICU with ventricular fibrillation to the patient who comes with a laundry list of chronic comorbidities. Especially within the context of a particular admission, the single acute condition seems more lethal.

Taking a first look at the data

Do we have the data to answer this question?

If you browse through the list of tables in the MIMIC dataset, you'll find that the dataset records whether the patient passed away during the course of their admission. We can also operationalize the definition of case complexity by counting the number of diagnoses that the patient had during an admission. More diagnoses means greater case complexity.

We need to check that we have a sufficiently diverse sample to build a viable model. First we'll check our dependent variable, which measures whether a patient passed away.


In [0]:
%%bigquery
SELECT
  COUNT(*) as total,
  SUM(HOSPITAL_EXPIRE_FLAG) as died
FROM
  `{admissions_table}`

Clearly the ICU is a very serious place: about 10% of admissions end in the patient's death. As data scientists, this tells us that we have a significant, albeit imbalanced, number of samples in both categories. The models we're training will easily adapt to this class imbalance, but we will need to be cautious when evaluating the performance of our models. After all, a model that simply says "no one dies" will be right about 90% of the time.

Next we'll look at the distribution of our independent variable: the number of diagnoses assigned to a patient during their admission.


In [0]:
%%bigquery hist_df
SELECT
  n_diagnoses, COUNT(*) AS cnt
FROM (
  SELECT
    COUNT(*) AS n_diagnoses
  FROM
    `{diagnoses_icd_table}`
  GROUP BY
    HADM_ID
)
GROUP BY n_diagnoses
ORDER BY n_diagnoses

In [0]:
g = sns.barplot(
    x=hist_df.n_diagnoses, y=hist_df.cnt, color=sns.color_palette()[0])
# Keep only every fifth label (and the first) on the x-axis for readability
for i, label in enumerate(g.get_xticklabels()):
  if i % 5 != 4 and i != 0:
    label.set_visible(False)

With the exception of the dramatic mode¹, the distribution of the diagnosis counts is bell shaped. The mathematical explanation of this is the central limit theorem. While this is by no means a deal breaker, the thin tails we see in the distribution can be a challenge for linear-regression models. This is because the extreme points tend to affect the likelihood the most, so having fewer of them makes your model more sensitive to outliers. Regularization can help with this, but if it becomes too much of a problem we can consider a different type of model (such as support-vector machines, or robust regression) instead of generalized linear regression.


¹ Which is sort of fascinating. Comparing the most common diagnoses for admissions with exactly 9 diagnoses to the rest of the cohort seems to suggest that this is due to positive correlations between cardiac diagnoses, e.g. cardiac complications NOS, mitral valve disorders, aortic valve disorders, subendocardial infarction etc. Your team might be interested in investigating this more seriously, especially if there is a cardiologist among you.

Creating a classification model

Creating a model with BigQuery ML is simple. You write a normal query in standard SQL, and each row of the result is used as an input to train your model. BigQuery ML automatically applies the required transformations depending on each variable's data type. For example, STRINGs are transformed into one-hot vectors, and TIMESTAMPs are standardized. These transformations are necessary to get a valid result, but they're easy to forget and a pain to implement. Without BQML, you also have to remember to apply these transformations when you make predictions and plots. It's fantastic that BigQuery takes care of all this for you.
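
To give a sense of what BQML saves you from, here's a sketch of what one-hot encoding the GENDER column by hand might look like (this is for illustration only; BQML does the equivalent automatically):

%%bigquery
SELECT
  IF(GENDER = 'F', 1.0, 0.0) AS gender_f,
  IF(GENDER = 'M', 1.0, 0.0) AS gender_m
FROM
  `{patients_table}`
LIMIT 5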

BigQuery ML also automatically performs validation-based early stopping to prevent overfitting.

To start, we're going to create a (regularized) logistic regression model that uses a single variable, the number of diagnoses a patient had during an admission, to predict the probability that a patient will pass away during an ICU admission.


In [0]:
%%bigquery
# BigQuery ML create model statement:
CREATE OR REPLACE MODEL `mimic_models.complexity_mortality_{suffix}`
OPTIONS(
  # Use logistic_reg for discrete predictions (classification) and linear_reg
  # for continuous predictions (forecasting).
  model_type = 'logistic_reg',
  # See the below aside (𝜎 = 0.5 ⇒ 𝜆 = 2)
  l2_reg = 2,
  # Identify the column to use as the label (dependent variable)
  input_label_cols = ["died"]
)
AS
# standard SQL query to train the model with:
SELECT
  COUNT(*) AS number_of_diagnoses,
  MAX(HOSPITAL_EXPIRE_FLAG) as died
FROM
  `{admissions_table}`
  INNER JOIN `{diagnoses_icd_table}`
  USING (HADM_ID)
GROUP BY HADM_ID

Optional aside: picking the regularization penalty $(\lambda)$ with Bayes' Theorem

From the frequentist point of view, $l_2$ regularized regression minimizes the negative log-likelihood of a model with an added penalty term: $\lambda \| w \|^2$. This penalty term reflects our desire for the model to be as simple as possible, and it removes the degeneracies caused by collinear input variables.

$\lambda$ is called l2_reg in BigQuery ML model options. You're given the freedom to set it to anything you want. In general, larger values of lambda encourage the model to give simpler explanations¹, and smaller values give the model more freedom to match the observed data. So what should you set $\lambda$ (a.k.a l2_reg) to?

A short calculation (see e.g. chapters 4.3.2 and 4.5.1 of Pattern Recognition and Machine Learning) shows that $l_2$ penalized logistic regression is equivalent to Bayesian logistic regression with the prior $w \sim \mathcal{N}(0, \sigma^2 = \frac{1}{2 \lambda})$.
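
This relationship is where the l2_reg = 2 in the CREATE MODEL statement above comes from. Choosing a prior scale of $\sigma = 0.5$ and solving for $\lambda$ gives

$$ \lambda = \frac{1}{2 \sigma^2} = \frac{1}{2 \times 0.5^2} = 2 $$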

Later on in this tutorial, we'll run an $l_1$ regularized regression, which means the penalty term is $\lambda \| w \|_1$. The same reasoning applies, except the corresponding prior is $w \sim \text{Laplace}(0, b = \frac{1}{\lambda})$.

This Bayesian perspective gives meaning to the value of $\lambda$. It reflects our prior uncertainty towards the strength of the relationship that we're modeling.

Since BQML automatically standardizes and one-hot encodes its inputs, we can use this interpretation to give some generic advice on choosing $\lambda$. If you don't have any special information, then any value of $\lambda$ around $1$ is reasonable, and reflects that even a perfect correlation between the input and the output is not too surprising.

As long as you choose $\lambda$ to be much less than your sample size, its exact value should not influence your results very much. And even very small values of $\lambda$ can remedy problems due to collinear inputs.


¹ Although regularization helps with overfitting, it does not completely solve it, and due care should still be taken not to select too many inputs for too little data.

Plotting the predictions

We can inspect the weights that our model learned using the ML.WEIGHTS statement. The positive weight that we see for number_of_diagnoses is our first evidence that case complexity is associated with mortality.


In [0]:
%%bigquery simple_model_weights
SELECT * FROM ML.WEIGHTS(MODEL `mimic_models.complexity_mortality_{suffix}`)

By default, the weights are automatically translated to their unstandardized forms, meaning that we don't have to standardize our inputs before multiplying them with the weights to obtain predictions. You can see the standardized weights with ML.WEIGHTS(MODEL ..., STRUCT(true AS standardize)), which can be helpful for answering questions about the relative importance of different variables, regardless of their scale.
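
For example, this cell (a minimal sketch of the call described above) returns the standardized weights of the model we just trained:

%%bigquery
SELECT * FROM ML.WEIGHTS(MODEL `mimic_models.complexity_mortality_{suffix}`,
                         STRUCT(true AS standardize))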

We can use the unstandardized weights to make a Python function that returns the predicted probability of mortality given an ICU admission with a certain number of diagnoses

import scipy.special

def predict(number_of_diagnoses):
  # This indexing assumes the first row of simple_model_weights holds the
  # number_of_diagnoses weight and the second holds the intercept.
  return scipy.special.expit(
    simple_model_weights.weight[0] * number_of_diagnoses
    + simple_model_weights.weight[1])

but it's often faster and easier to make predictions with the ML.PREDICT statement.

We'd like to create a plot showing our model's predictions and the underlying data. We can use ML.PREDICT to get the data to draw the prediction line, and copy-paste the query we fed into CREATE MODEL to get the data points.


In [0]:
# Cast to a plain Python int so the value serializes cleanly as a query parameter
params = {'max_prediction': int(hist_df.n_diagnoses.max())}

In [0]:
%%bigquery line_df --params $params
SELECT * FROM
ML.PREDICT(MODEL `mimic_models.complexity_mortality_{suffix}`, (
  SELECT * FROM
  UNNEST(GENERATE_ARRAY(1, @max_prediction)) AS number_of_diagnoses
))

In [0]:
%%bigquery scatter_df
SELECT
  COUNT(*) AS num_diag,
  MAX(HOSPITAL_EXPIRE_FLAG) as died
FROM
  `{admissions_table}` AS adm
  INNER JOIN `{diagnoses_icd_table}` AS diag
  USING (HADM_ID)
GROUP BY HADM_ID

In [0]:
sns.regplot(
    x='num_diag',
    y='died',
    data=scatter_df,
    fit_reg=False,
    x_bins=np.arange(1,
                     scatter_df.num_diag.max() + 1))
plt.plot(line_df.number_of_diagnoses,
         line_df.predicted_died_probs.apply(lambda x: x[0]['prob']))
plt.xlabel('Case complexity (number of diagnoses)')
plt.ylabel('Probability of death during admission')

Qualitatively, our model fits the data quite well, and the trend is pretty clear. We might be tempted to say we've proven that increasing case complexity increases the probability of death during an admission to the ICU. While we've provided some evidence of this, we haven't proven it yet.

The biggest problem is that we don't know whether case complexity is causing the increase in deaths, or if it is merely correlated with some other variables that affect the probability of death more directly.

Adding a confounding variable

Patient age is a likely candidate for a confounding variable that could be mediating the relationship between complexity and mortality. Patients generally accrue diagnoses as they age¹ and approach their life expectancy. By adding the patient's age to our model, we can see how much of the relationship between case complexity and mortality is explained by the patient's age.


¹ Using the CORR standard SQL function, you can calculate that the Pearson correlation coefficient between age and number of diagnoses is $0.37$.
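
For reference, here's a sketch of that calculation. It mirrors the joins and the age-capping logic used in the model below:

%%bigquery
SELECT
  CORR(IF(DATETIME_DIFF(ADMITTIME, DOB, DAY)/365.25 < 200,
          DATETIME_DIFF(ADMITTIME, DOB, DAY)/365.25, 95),
       num_diag) AS age_diag_corr
FROM (
  SELECT
    SUBJECT_ID,
    COUNT(*) AS num_diag,
    ANY_VALUE(ADMITTIME) AS ADMITTIME
  FROM
    `{admissions_table}`
    JOIN `{diagnoses_icd_table}` USING (HADM_ID, SUBJECT_ID)
  GROUP BY HADM_ID, SUBJECT_ID
)
JOIN `{patients_table}` USING (SUBJECT_ID)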


In [0]:
%%bigquery
CREATE OR REPLACE MODEL `mimic_models.complexity_age_mortality_{suffix}`
OPTIONS(model_type='logistic_reg', l2_reg=2, input_label_cols=["died"])
AS
SELECT
  # MIMIC3 sets all ages over 89 to 300 to avoid the possibility of
  # identification.
  IF(DATETIME_DIFF(ADMITTIME, DOB, DAY)/365.25 < 200,
     DATETIME_DIFF(ADMITTIME, DOB, DAY)/365.25,
     # The life expectancy of a 90 year old is approximately 5 years according
     # to actuarial tables. So we'll use 95 as the mean age of 90+'s
     95) AS age,
  num_diag,
  died
FROM
  (SELECT
    COUNT(*) AS num_diag,
    MAX(HOSPITAL_EXPIRE_FLAG) as died,
    ANY_VALUE(ADMITTIME) as ADMITTIME,
    SUBJECT_ID
  FROM
    `{admissions_table}` AS adm
    JOIN `{diagnoses_icd_table}` AS diag
  USING (HADM_ID, SUBJECT_ID)
  GROUP BY HADM_ID, SUBJECT_ID
  )
  JOIN `{patients_table}` AS patients
  USING (SUBJECT_ID)

When we investigate the weights for this model, we see the weight associated with the number of diagnoses is only slightly smaller now. This tells us that some of the effect we saw in the univariate model was due to the confounding influence of age, but most of it wasn't.


In [0]:
%%bigquery
SELECT * FROM ML.WEIGHTS(MODEL `mimic_models.complexity_age_mortality_{suffix}`)

Another way to understand this relationship is to compare the effectiveness of the model with and without age as an input. This answers the question: given the number of diagnoses that a patient has received, how much extra information does their age give us? To be thorough, we could also include a model with just the patient's age. You can add a couple of code cells to this notebook and do this as an exercise if you're curious.
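
If you'd like a head start on that exercise, here's a sketch of the age-only model. The model name age_mortality_{suffix} is our own invention, and the age logic matches the bivariate model above:

%%bigquery
CREATE OR REPLACE MODEL `mimic_models.age_mortality_{suffix}`
OPTIONS(model_type='logistic_reg', l2_reg=2, input_label_cols=["died"])
AS
SELECT
  IF(DATETIME_DIFF(ADMITTIME, DOB, DAY)/365.25 < 200,
     DATETIME_DIFF(ADMITTIME, DOB, DAY)/365.25, 95) AS age,
  HOSPITAL_EXPIRE_FLAG AS died
FROM
  `{admissions_table}`
  JOIN `{patients_table}` USING (SUBJECT_ID)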

Plotting ROC and precision-recall curves

One way to compare the effectiveness of binary classification models is with ROC curves or precision-recall curves.

Since ROC curves tend to appear overly optimistic when the data has a significant class imbalance, we're going to favour precision-recall curves in this tutorial. Precision-recall curves plot the recall (which measures the model's performance on the positive samples)

$$ \text{Recall} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Negatives}} $$

against the precision (which measures the model's performance on the samples it classified as positive examples)

$$ \text{Precision} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Positives}} $$

as the decision threshold ranges from $0$ (predict everyone dies) to $1$ (predict no one dies)¹.

To make these plots, we're going to use the ML.ROC_CURVE BigQuery ML statement. ML.ROC_CURVE returns the data you need to draw both ROC and precision-recall curves with your graphing library of choice.

ML.ROC_CURVE defaults to using data from the evaluation dataset. If it operated on the training dataset, it would be difficult to distinguish overfitting from excellent performance. If you have your own validation dataset, you can provide it as an optional second argument.
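
If you do have held-out data, the call looks something like this sketch (the holdout table name here is hypothetical):

%%bigquery
SELECT * FROM
ML.ROC_CURVE(MODEL `mimic_models.complexity_mortality_{suffix}`, (
  SELECT number_of_diagnoses, died
  FROM `my-project.my_dataset.my_holdout_table`
))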


¹ BigQuery ML uses the convention that the threshold is between $0$ and $1$, rather than the logit of this value.


In [0]:
%%bigquery comp_roc
SELECT * FROM ML.ROC_CURVE(MODEL `mimic_models.complexity_mortality_{suffix}`)

In [0]:
%%bigquery comp_age_roc
SELECT * FROM
ML.ROC_CURVE(MODEL `mimic_models.complexity_age_mortality_{suffix}`)

In [0]:
def set_precision(df):
  df['precision'] = df.true_positives / (df.true_positives + df.false_positives)


def plot_precision_recall(df, label=None):
  # Drop rows with no true positives (where precision is ill-defined), then
  # manually add the (recall=0, precision=1) point for the maximal threshold
  df = df[df.true_positives != 0]
  recall = [0] + list(df.recall)
  precision = [1] + list(df.precision)
  # x=recall, y=precision line chart
  plt.plot(recall, precision, label=label)

In [0]:
set_precision(comp_roc)
set_precision(comp_age_roc)
plot_precision_recall(comp_age_roc, label='bivariate (age) model')
plot_precision_recall(comp_roc, label='univariate model')
plt.plot(
    np.linspace(0, 1, 2), [comp_roc.precision.min()] * 2,
    label='null model',
    linestyle='--')
plt.legend()
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.xlabel(r'Recall $\left(\frac{T_p}{T_p + F_n} \right)$')
plt.ylabel(r'Precision $\left(\frac{T_p}{T_p + F_p} \right)$')

We see that:

  • Both of these models are significantly better than the zero-variable (null) model, implying that case complexity has a significant impact on patient mortality.
  • Adding the patient's age only marginally improves the model, implying that the impact of case complexity is not mediated through age.

Of course, neither of these models is very good when it comes to making predictions. For our last set of models, we'll try more earnestly to predict patient mortality.

More complex models

One of the main attractions of BigQuery ML is its ability to scale to high-dimensional models with up to millions of variables. Our dataset isn't nearly large enough to train this many variables without severe overfitting, but we can still justify training models with hundreds of variables.

Our strategy will use the $m$ most frequent diagnoses, and a handful of other likely relevant variables as the inputs to our model. Namely, we'll use:

  • ADMISSION_TYPE: reflects the reason for, and seriousness of the admission
    • urgent
    • emergency
    • newborn
    • elective
  • INSURANCE: reflects the patient's socioeconomic status, a well-known covariate with patient outcomes
    • Self Pay
    • Medicare
    • Private
    • Medicaid
    • Government
  • GENDER: accounts for both social and physiological differences across genders
  • AGE: accounts for both social and physiological differences across ages
  • number of diagnoses: our stand-in for case complexity

in addition to the top $m$ diagnoses. We'll compare models with $m \in \left\{8, 16, 32, 64, 128, 256, 512 \right\}$ to determine the most sensible value of $m$.

This will give us valuable information regarding our original question: whether case complexity increases the probability of ICU mortality. We wonder if the number of diagnoses increases patient risk only because it increases the chances of one of their many diagnoses being lethal, or if there is an interaction effect¹. We'll be able to test this by determining whether $w_{n_{\text{diagnoses}}}$ goes to $0$ as we increase $m$.

We'll also get some interesting information on the relative lethality of different diagnoses, and how these compare with social determinants.


¹As in the often misattributed quote: quantity has a quality all its own, or does it?

Creating the models

We'll start by getting a list of the most frequent diagnoses


In [0]:
%%bigquery top_diagnoses
WITH top_diag AS (
  SELECT COUNT(*) AS count, ICD9_CODE FROM `{diagnoses_icd_table}`
  GROUP BY ICD9_CODE
)
SELECT top_diag.ICD9_CODE, icd_lookup.SHORT_TITLE, top_diag.count FROM
top_diag JOIN
 `{d_icd_diagnoses_table}` AS icd_lookup
USING (ICD9_CODE)
ORDER BY count DESC LIMIT 1024

which we'll use to create our models. In the CREATE MODEL SELECT statement, we create one column for each of the $m$ diagnoses and fill it with $1$ if the patient had that diagnosis and $0$ otherwise.

This time around we're using l1_reg instead of l2_reg because we expect that some of our many variables will not significantly impact the outcome, and we would prefer a sparse model if possible.


In [0]:
top_n_diagnoses = (8, 16, 32, 64, 128, 256, 512)

In [0]:
query_jobs = list()
for m in top_n_diagnoses:
  # The expressions for creating the new columns for each input diagnosis
  diagnosis_columns = list()
  for _, row in top_diagnoses.iloc[:m].iterrows():
    diagnosis_columns.append('MAX(IF(ICD9_CODE = "{0}", 1.0, 0.0))'
                             ' as `icd9_{0}`'.format(row.ICD9_CODE))

  query = """
  CREATE OR REPLACE MODEL `mimic_models.predict_mortality_diag_{m}_{suffix}`
  OPTIONS(model_type='logistic_reg', l1_reg=2, input_label_cols=["died"])
  AS
  WITH diagnoses AS (
    SELECT
      HADM_ID,
      COUNT(*) AS num_diag,
      {diag_cols}
    FROM `{diagnoses_icd_table}`
    WHERE ICD9_CODE IS NOT NULL
    GROUP BY HADM_ID
  )
  SELECT
    IF(DATETIME_DIFF(adm.ADMITTIME, patients.DOB, DAY)/365.25 < 200,
       DATETIME_DIFF(adm.ADMITTIME, patients.DOB, DAY)/365.25, 95) AS age,
    diagnoses.* EXCEPT (HADM_ID),
    adm.HOSPITAL_EXPIRE_FLAG as died,
    adm.ADMISSION_TYPE as adm_type,
    adm.INSURANCE as insurance,
    patients.GENDER
  FROM
    `{admissions_table}` AS adm
    LEFT JOIN `{patients_table}` AS patients USING (SUBJECT_ID)
    LEFT JOIN diagnoses USING (HADM_ID)
  """.format(
      m=m, diag_cols=',\n    '.join(diagnosis_columns), **sub_dict)
  # Run the query, and track its progress with query_jobs
  query_jobs.append(bq.query(query))

# Wait for all of the models to finish training
for j in query_jobs:
  j.exception()

Getting evaluation metrics

To obtain numerical evaluation metrics on your models, BigQuery ML provides the ML.EVALUATE statement. Just like ML.ROC_CURVE, ML.EVALUATE defaults to using the evaluation dataset that was set aside when the model was created.


In [0]:
eval_queries = list()
for m in top_n_diagnoses:
  eval_queries.append(
      'SELECT * FROM ML.EVALUATE('
      'MODEL `mimic_models.predict_mortality_diag_{}_{suffix}`)'
      .format(m, **sub_dict))
eval_query = '\nUNION ALL\n'.join(eval_queries)
bq.query(eval_query).result().to_dataframe()

And we can also plot the precision-recall curves as we did before.


In [0]:
for m in top_n_diagnoses:
  df = bq.query('SELECT * FROM ML.ROC_CURVE('
                'MODEL `mimic_models.predict_mortality_diag_{}_{suffix}`)'
                .format(m, **sub_dict)).result().to_dataframe()
  set_precision(df)
  plot_precision_recall(df, label='{} diagnoses'.format(m))

plt.plot(
    np.linspace(0, 1, 2), [df.precision.min()] * 2,
    label='null model',
    linestyle='--')
plt.legend()
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.xlabel(r'Recall $\left(\frac{T_p}{T_p + F_n} \right)$')
plt.ylabel(r'Precision $\left(\frac{T_p}{T_p + F_p} \right)$')

The model with $m = 512$ seems to be overfitting the data, while somewhere between $m = 128$ and $m = 256$ seems to be the sweet spot for model flexibility. Since we've now used the evaluation dataset to determine $m$ (albeit informally), as well as when to stop early during training, dogmatic rigour would demand that we measure our model on a third (validation) dataset before we brag about its efficacy. On the other hand, there isn't a ton of flexibility in choosing between a few different values of $m$, nor in when to stop early. You can use your own judgment.

Actually, the predictive power of our model¹ isn't nearly as interesting as its weights and what they tell us. In the next section, we'll dig into them.


¹Which could be described as approaching respectability, but still a long way from brag-worthy.

Exploring our model

Let's have a look at the weights from the $m = 128$ model.


In [0]:
%%bigquery weights_128
SELECT * FROM ML.WEIGHTS(MODEL `mimic_models.predict_mortality_diag_128_{suffix}`)
ORDER BY weight DESC

First we'll look at the weights for the numerical inputs.


In [0]:
pd.set_option('display.max_rows', 150)
weights_128['ICD9_CODE'] = weights_128.processed_input \
  .apply(lambda x: x[len('icd9_'):] if x.startswith('icd9_') else x)
view_df = weights_128.merge(top_diagnoses, how='left', on='ICD9_CODE') \
  .rename(columns={'ICD9_CODE': 'input'})
view_df = view_df[~pd.isnull(view_df.weight)]
view_df[['input', 'SHORT_TITLE', 'weight', 'count']]

We now have a list of diagnoses, sorted from most to least fatal according to our model.

Going back to our original question, we can see that the weight for num_diag (a.k.a. the number of diagnoses) has essentially gone to zero. The average diagnosis weight is also very small:


In [0]:
view_df[~pd.isnull(view_df.SHORT_TITLE)].weight.mean()

so we can conclude that given that a patient has been admitted to the ICU, the number of diagnoses they've been given does not predict their outcome beyond the linear effect of the component diagnoses.
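
To check how this weight behaves as $m$ grows, you can filter the output of ML.WEIGHTS like any other table. Here's a sketch comparing two of the models we trained:

%%bigquery
SELECT 128 AS m, weight
FROM ML.WEIGHTS(MODEL `mimic_models.predict_mortality_diag_128_{suffix}`)
WHERE processed_input = 'num_diag'
UNION ALL
SELECT 512 AS m, weight
FROM ML.WEIGHTS(MODEL `mimic_models.predict_mortality_diag_512_{suffix}`)
WHERE processed_input = 'num_diag'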

It might be surprising that the weight for age is also very small. One explanation might be that DNR¹ status and falls are among the highest-weighted diagnoses. These diagnoses are associated with advanced age² ³ and there is literature³ to support that DNR status mediates the effect of age on survival. One thing we couldn't find much data on was the relationship between age and palliative treatment. This could be a good subject for a datathon team to tackle.


¹Do not resuscitate

²Article: Age-Related Changes in Physical Fall Risk Factors: Results from a 3 Year Follow-up of Community Dwelling Older Adults in Tasmania, Australia

³Article: Do Not Resuscitate (DNR) Status, Not Age, Affects Outcomes after Injury: An Evaluation of 15,227 Consecutive Trauma Patients

Now let's look at the weights for the categorical variables.


In [0]:
for _, row in weights_128[pd.isnull(weights_128.weight)].iterrows():
  print(row.processed_input)
  print(
      *sorted([tuple(x.values()) for x in row.category_weights],
              key=lambda x: x[1],
              reverse=True),
      sep='\n',
      end='\n\n')

We see that the patient's insurance has a startlingly large effect in our model.

For those of us not familiar with American medical insurance terminology¹:

  • Self pay: the patient pays out-of-pocket for their medical care as they require it
  • Medicare: a government program for people who are over 65 years old or have a disability
  • Private: insurance that is usually paid for by the patient's employer
  • Medicaid: a government program for people with low incomes
  • Government: insurance granted by the government excluding Medicare and Medicaid. This includes government employees and veterans.

The impact of socioeconomic status on health² is on clear display here. The difference between the weights for Medicare and private insurance is $0.25$, which is similar to the weight for atrial fibrillation.

The outlook for patients paying out of pocket is also grim, and may reflect an avoidance of hospital care for financial reasons in addition to other socioeconomic factors.

The weights for the admission type seem to reflect common sense, as do the weights for gender, given that females have a longer life expectancy than males³.


¹ See https://en.wikipedia.org/wiki/Health_insurance_in_the_United_States

² There are thousands of articles on this, see e.g. Article: Socioeconomic Disparities in Health in the United States: What the Patterns Tell Us

³ See https://en.wikipedia.org/wiki/List_of_countries_by_life_expectancy

Conclusion

We've found evidence that case complexity increases the mortality risk of an ICU admission, but only through the cumulative effects of the component diagnoses. That's not to say that these nonlinear interactions aren't very powerful in certain cases¹, but that this seems to be the exception rather than the rule.

We were able to obtain these results entirely from within BigQuery, with minimal modifications to standard SQL statements, only resorting to Python for visualization.


¹ That is, between certain combinations or cliques of diagnoses.