Setup


In [ ]:
import tensorflow as tf
from google.protobuf import text_format
import os

# Either set the path to the directory with the generated tf.SequenceExamples split into train and validation
# sets, or set it to a directory in which to store synthetically generated examples.
path = '/tmp/'
output_dir = '/tmp/models/'
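
# Optional: create the model directory up front (a small convenience; the
# estimator will also create it automatically if it is missing).
if not os.path.exists(output_dir):
    os.makedirs(output_dir)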

Synthetic data generation

This step is optional - run it only if you want to train and evaluate the model on a synthetic example rather than on the output of the data generation pipeline.


In [ ]:
sequence_examples_pbtxt = """
context {
  feature: {
    key  : "Patient.gender"
    value: {
      bytes_list: {
        value: [ "male" ]
      }
    }
  }
  feature: {
    key  : "label.length_of_stay_range.class"
    value: {
      bytes_list: {
        value: [ "above_14" ]
      }
    }
  }
  feature: {
    key  : "Patient.birthDate"
    value: {
      int64_list: {
        value: [ -1167580800 ]
      }
    }
  }
  feature: {
    key  : "timestamp"
    value: {
      int64_list: {
        value: [ 1528917657 ]
      }
    }
  }
  feature: {
    key  : "sequenceLength"
    value: {
      int64_list: {
        value: [ 4 ]
      }
    }
  }
}
feature_lists {
  feature_list {
    key: "Observation.code"
    value {
      feature {
        bytes_list {
          value: "loinc:2"
        }
      }
      feature {
        bytes_list {
          value: "loinc:4"
        }
      }
      feature {
        bytes_list {
          value: "loinc:6"
        }
      }
      feature {
        bytes_list {
          value: "loinc:6"
        }
      }
    }
  }
  feature_list {
    key: "Observation.value.quantity.value"
    value {
      feature {
        float_list {
          value: 1.0
        }
      }
      feature {
        float_list {
          value: 2.0
        }
      }
      feature {
        float_list {
        }
      }
      feature {
        float_list {
        }
      }
    }
  }
  feature_list {
    key: "Observation.value.quantity.unit"
    value {
      feature {
        bytes_list {
          value: "mg/L"
        }
      }
      feature {
        bytes_list {
        }
      }
      feature {
        bytes_list {
        }
      }
      feature {
        bytes_list {
        }
      }
    }
  }
  feature_list {
    key: "eventId"
    value {
      feature {
        int64_list {
          value: 1528917644
        }
      }
      feature {
        int64_list {
          value: 1528917645
        }
      }
      feature {
        int64_list {
          value: 1528917646
        }
      }
      feature {
        int64_list {
          value: 1528917647
        }
      }
    }
  }
}
"""

In [ ]:
seqex_pb = tf.train.SequenceExample()
text_format.Parse(sequence_examples_pbtxt, seqex_pb)

In [ ]:
# Write 100 copies of the synthetic example into the training and validation files.
for split in ['train', 'validation']:
    with tf.io.TFRecordWriter(path + split) as writer:
        for seqex in [seqex_pb] * 100:
            writer.write(seqex.SerializeToString())
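
As an optional sanity check (a minimal sketch), read back the files written above and confirm that each split contains the expected 100 serialized examples.


In [ ]:
# Count the serialized tf.SequenceExamples in each split.
for split in ['train', 'validation']:
    num_records = sum(1 for _ in tf.python_io.tf_record_iterator(path + split))
    print(split, num_records)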

Define Model

Declare Hyperparameters


In [ ]:
def create_hparams(hparams_overrides=None):
  """Creates default HParams with the option of overrides.

  Args:
    hparams_overrides: HParams overriding the otherwise provided defaults.
      Defaults to None (meaning no overrides take place). The HParams specified
      need to reference a subset of the defaults.

  Returns:
    Default HParams.
  """
  hparams = tf.contrib.training.HParams(
      # Sequence features are bucketed by their age at time of prediction into
      # the windows (time_windows[1], time_windows[0]],
      # (time_windows[2], time_windows[1]],
      # ...
      time_windows=[
          5 * 365 * 24 * 60 * 60,  # 5 years
          365 * 24 * 60 * 60,  # 1 year
          30 * 24 * 60 * 60,  # 1 month
          7 * 24 * 60 * 60,  # 1 week
          1 * 24 * 60 * 60,  # 1 day
          0,  # now
      ],
      batch_size=64,
      learning_rate=0.003,
      dedup=True,
      l1_regularization_strength=0.0,
      l2_regularization_strength=0.0,
      include_age=True,
      age_boundaries=[1, 5, 18, 30, 50, 70, 90],
      categorical_context_features=['Patient.gender'],
      sequence_features=[
          'Composition.section.text.div.tokenized',
          'Composition.type',
          'Condition.code',
          'Encounter.hospitalization.admitSource',
          'Encounter.reason.hcc',
          'MedicationRequest.contained.medication.code.gsn',
          'Procedure.code.cpt',
      ],
      # Number of hash buckets to map the tokens of the sequence_features into.
      sequence_bucket_sizes=[
          17000,
          16,
          3052,
          10,
          62,
          1600,
          732,
      ],
      # List of strings, each of which is a ':'-separated list of features that
      # we want to cross at each step along the time dimension.
      time_crossed_features=[
          '%s:%s:%s:%s' % ('Observation.code',
                           'Observation.value.quantity.value',
                           'Observation.value.quantity.unit',
                           'Observation.value.string')
      ],
      time_concat_bucket_sizes=[39571],
      context_bucket_sizes=[4])
  # Other overrides (possibly coming from a tuning service such as Vizier) are
  # applied on top of the defaults.
  if hparams_overrides:
    hparams.override_from_dict(hparams_overrides.values())
  return hparams
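
As a quick smoke test (not required for training), the defaults can be constructed and a few of the values defined above inspected.


In [ ]:
hparams = create_hparams()
print(hparams.batch_size, hparams.learning_rate)
print(hparams.sequence_features)
print(hparams.time_windows)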

Setup input function


In [ ]:
CONTEXT_KEY_PREFIX = 'c-'
SEQUENCE_KEY_PREFIX = 's-'
AGE_KEY = 'Patient.ageInYears'

LABEL_VALUES = ['less_or_equal_3', '3_7', '7_14', 'above_14']


def _example_index_to_sparse_index(example_indices, batch_size):
  """Creates a sparse index tensor from a list of example indices.

  For example, this would do the transformation:
  [0, 0, 0, 1, 3, 3] -> [[0,0], [0,1], [0,2], [1,0], [3,0], [3,1]]

  The second column of the output tensor is the running count of the occurrences
  of that example index.

  Args:
    example_indices: A sorted 1D Tensor with example indices.
    batch_size: The batch_size. Could be larger than max(example_indices) if the
      last examples of the batch do not have the feature present.
  Returns:
    The sparse index tensor.
    The maximum length of a row in this tensor.
  """
  binned_counts = tf.bincount(example_indices, minlength=batch_size)
  max_len = tf.to_int64(tf.reduce_max(binned_counts))
  return tf.where(tf.sequence_mask(binned_counts)), max_len

def _dedup_tensor(sp_tensor):
  """Dedup values of a SparseTensor along each row.

  Args:
    sp_tensor: A 2D SparseTensor to be deduped.
  Returns:
    A deduped SparseTensor of shape [batch_size, max_len], where max_len is
    the maximum number of unique values for a row in the Tensor.
  """
  string_batch_index = tf.as_string(sp_tensor.indices[:, 0])

  # tf.unique only works on 1D tensors. To avoid deduping across examples,
  # prepend each feature value with the example index. This requires casting
  # to and from strings for non-string features.
  original_dtype = sp_tensor.values.dtype
  string_values = (
      sp_tensor.values
      if original_dtype == tf.string else tf.as_string(sp_tensor.values))
  index_and_value = tf.string_join([string_batch_index, string_values],
                                   separator='|')
  unique_index_and_value, _ = tf.unique(index_and_value)

  # split is a shape [tf.size(values), 2] tensor. The first column contains
  # indices and the second column contains the feature value (we assume no
  # feature value contains '|' so we get exactly 2 values from the string split).
  split = tf.string_split(unique_index_and_value, delimiter='|')
  split = tf.reshape(split.values, [-1, 2])
  string_indices = split[:, 0]
  values = split[:, 1]

  indices = tf.reshape(
      tf.string_to_number(string_indices, out_type=tf.int32), [-1])
  if original_dtype != tf.string:
    values = tf.string_to_number(values, out_type=original_dtype)
  values = tf.reshape(values, [-1])
  # Convert example indices into SparseTensor indices, e.g.
  # [0, 0, 0, 1, 3, 3] -> [[0,0], [0,1], [0,2], [1,0], [3,0], [3,1]]
  batch_size = tf.to_int32(sp_tensor.dense_shape[0])
  new_indices, max_len = _example_index_to_sparse_index(indices, batch_size)
  return tf.SparseTensor(
      indices=tf.to_int64(new_indices),
      values=values,
      dense_shape=[tf.to_int64(batch_size), max_len])

def get_input_fn(mode,
                 input_pattern,
                 dedup,
                 time_windows,
                 include_age,
                 categorical_context_features,
                 sequence_features,
                 time_crossed_features,
                 batch_size,
                 shuffle=True):
  """Creates an input function to an estimator.

  Args:
    mode: The execution mode, as defined in tf.estimator.ModeKeys.
    input_pattern: Input data pattern in TFRecord format containing
      tf.SequenceExamples.
    dedup: Whether to remove duplicate values.
    time_windows: List of time windows - we bucket all sequence features by
      their age into buckets (time_windows[i+1], time_windows[i]].
    include_age: Whether to include the age_in_years as a feature.
    categorical_context_features: List of string context features that are valid
      keys in the tf.SequenceExample.
    sequence_features: List of sequence features (strings) that are valid keys
      in the tf.SequenceExample.
    time_crossed_features: List of lists of sequence features (strings) that
      should be crossed at each step along the time dimension.
    batch_size: The size of the batch when reading in data.
    shuffle: Whether to shuffle the examples.

  Returns:
    A function that returns a dictionary of features and the target labels.
  """

  def input_fn():
    """Supplies input to our model.

    This function supplies input to our model, where this input is a
    function of the mode. For example, we supply different data if
    we're performing training versus evaluation.

    Returns:
      A tuple consisting of 1) a dictionary of tensors whose keys are
      the feature names, and 2) a tensor of target labels if the mode
      is not PREDICT (and None otherwise).
    """

    sequence_features_config = dict()
    for feature in sequence_features:
      dtype = tf.string
      if feature == 'Observation.value.quantity.value':
        dtype = tf.float32
      sequence_features_config[feature] = tf.VarLenFeature(dtype)

    sequence_features_config['eventId'] = tf.FixedLenSequenceFeature(
        [], tf.int64, allow_missing=False)
    for cross in time_crossed_features:
      for feature in cross:
        dtype = tf.string
        if feature == 'Observation.value.quantity.value':
          dtype = tf.float32
        sequence_features_config[feature] = tf.VarLenFeature(dtype)
    context_features_config = dict()
    if include_age:
      context_features_config['timestamp'] = tf.FixedLenFeature(
          [], tf.int64, default_value=-1)
      context_features_config['Patient.birthDate'] = tf.FixedLenFeature(
          [], tf.int64, default_value=-1)
    context_features_config['sequenceLength'] = tf.FixedLenFeature(
        [], tf.int64, default_value=-1)

    for context_feature in categorical_context_features:
      context_features_config[context_feature] = tf.VarLenFeature(tf.string)
    if mode != tf.estimator.ModeKeys.PREDICT:
      context_features_config['label.length_of_stay_range.class'] = (
          tf.FixedLenFeature([], tf.string, default_value='MISSING'))

    is_training = mode == tf.estimator.ModeKeys.TRAIN
    num_epochs = None if is_training else 1

    with tf.name_scope('read_batch'):
      file_names = [input_pattern]
      files = tf.data.Dataset.list_files(file_names)
      if shuffle:
        files = files.shuffle(buffer_size=len(file_names))
      dataset = (files
                 .apply(tf.data.experimental.parallel_interleave(
                     tf.data.TFRecordDataset, cycle_length=10))
                 .repeat(num_epochs))
      if shuffle:
        dataset = dataset.shuffle(buffer_size=100)
      dataset = dataset.batch(batch_size)

      def _parse_fn(serialized_examples):
        context, sequence, _ = tf.io.parse_sequence_example(
            serialized_examples,
            context_features=context_features_config,
            sequence_features=sequence_features_config,
            name='parse_sequence_example')
        return context, sequence

      dataset = dataset.map(_parse_fn, num_parallel_calls=8)

      def _process(context, sequence):
        """Supplies input to our model.

        This function supplies input to our model after parsing.

        Args:
          context: The dictionary from key to (Sparse)Tensors with context
            features
          sequence: The dictionary from key to (Sparse)Tensors with sequence
            features

        Returns:
          A tuple consisting of 1) a dictionary of tensors whose keys are
          the feature names, and 2) a tensor of target labels if the mode
          is not PREDICT (and None otherwise).
        """
        # Combine into a single dictionary.
        feature_map = {}
        # Add age if requested.
        if include_age:
          age_in_seconds = (
              context['timestamp'] -
              context.pop('Patient.birthDate'))
          age_in_years = tf.to_float(age_in_seconds) / (60 * 60 * 24 * 365.0)
          feature_map[CONTEXT_KEY_PREFIX + AGE_KEY] = age_in_years

        sequence_length = context.pop('sequenceLength')
        # Cross the requested features.
        for cross in time_crossed_features:
          # The features may be missing at different rates - we take the union
          # of the indices supplying defaults.
          extended_features = dict()
          dense_shape = tf.concat(
              [[tf.shape(sequence_length)[0]], [tf.reduce_max(sequence_length)],
               tf.constant([1], dtype=tf.int64)],
              axis=0)
          for i, feature in enumerate(cross):
            sp_tensor = sequence[feature]
            additional_indices = []
            covered_indices = sp_tensor.indices
            for j, other_feature in enumerate(cross):
              if i != j:
                additional_indices.append(
                    tf.sets.set_difference(
                        tf.sparse_reorder(
                            tf.SparseTensor(
                                indices=sequence[other_feature].indices,
                                values=tf.zeros([
                                    tf.shape(sequence[other_feature].indices)[0]
                                ],
                                                dtype=tf.int32),
                                dense_shape=dense_shape)),
                        tf.sparse_reorder(
                            tf.SparseTensor(
                                indices=covered_indices,
                                values=tf.zeros([tf.shape(covered_indices)[0]],
                                                dtype=tf.int32),
                                dense_shape=dense_shape))).indices)
                covered_indices = tf.concat(
                    [sp_tensor.indices] + additional_indices, axis=0)

            additional_indices = tf.concat(additional_indices, axis=0)

            # Supply defaults for all other indices.
            default = tf.tile(
                tf.constant(['n/a']),
                multiples=[tf.shape(additional_indices)[0]])

            string_value = (
                tf.as_string(sp_tensor.values)
                if sp_tensor.values.dtype != tf.string else sp_tensor.values)

            extended_features[feature] = tf.sparse_reorder(
                tf.SparseTensor(
                    indices=tf.concat([sp_tensor.indices, additional_indices],
                                      axis=0),
                    values=tf.concat([string_value, default], axis=0),
                    dense_shape=dense_shape))

          new_values = tf.string_join(
              [extended_features[f].values for f in cross], separator='-')
          crossed_sp_tensor = tf.sparse_reorder(
              tf.SparseTensor(
                  indices=extended_features[cross[0]].indices,
                  values=new_values,
                  dense_shape=extended_features[cross[0]].dense_shape))
          sequence['_'.join(cross)] = crossed_sp_tensor
        # Remove unwanted features that are used in the cross but should not be
        # considered outside the cross.
        for cross in time_crossed_features:
          for feature in cross:
            if feature not in sequence_features and feature in sequence:
              del sequence[feature]

        # Flatten sparse tensor to compute event age. This dense tensor also
        # contains padded values. These will not be used when gathering elements
        # from the dense tensor since each sparse feature won't have a value
        # defined for the padding.
        padded_event_age = (
            # Broadcast current time along sequence dimension.
            tf.expand_dims(context.pop('timestamp'), 1)
            # Subtract time of events.
            - sequence.pop('eventId'))

        for i in range(len(time_windows) - 1):
          max_age = time_windows[i]
          min_age = time_windows[i+1]
          padded_in_time_window = tf.logical_and(padded_event_age <= max_age,
                                                 padded_event_age > min_age)

          for k, v in sequence.items():
            # For each sparse feature entry, look up whether it is in the time
            # window or not.
            in_time_window = tf.gather_nd(padded_in_time_window,
                                          v.indices[:, 0:2])
            v = tf.sparse_retain(v, in_time_window)
            sp_tensor = tf.sparse_reshape(v, [v.dense_shape[0], -1])
            if dedup:
              sp_tensor = _dedup_tensor(sp_tensor)

            feature_map[SEQUENCE_KEY_PREFIX + k +
                        '-til-%d' % min_age] = sp_tensor

        for k, v in context.items():
          feature_map[CONTEXT_KEY_PREFIX + k] = v
        return feature_map

      feature_map = (dataset
                     # Parallelize the input processing and put it behind a
                     # queue to increase performance by removing it from the
                     # critical path of per-step-computation.
                     .map(_process, num_parallel_calls=8)
                     .prefetch(buffer_size=1)
                     .make_one_shot_iterator()
                     .get_next())
      label = None
      if mode != tf.estimator.ModeKeys.PREDICT:
        label = feature_map.pop(CONTEXT_KEY_PREFIX +
                                'label.length_of_stay_range.class')
      return feature_map, label
  return input_fn

Test input function


In [ ]:
tf.reset_default_graph()
train_file = path + 'train'
validation_file = path + 'validation'
hparams = create_hparams()

time_crossed_features = [
        cross.split(':') for cross in hparams.time_crossed_features if cross
    ]

map_, label_ = get_input_fn(tf.estimator.ModeKeys.TRAIN, train_file, True, hparams.time_windows,
                            hparams.include_age, hparams.categorical_context_features,
                            hparams.sequence_features, time_crossed_features, batch_size=2)()
with tf.train.MonitoredSession() as sess:
  map_['label'] = label_
  print(sess.run(map_))

Define features and model


In [ ]:
seq_features = []
seq_features_sizes = []
hparams = create_hparams()

for k, bucket_size in zip(
    hparams.sequence_features,
    hparams.sequence_bucket_sizes):
  # The input function names each bucketed sequence feature by its lower age
  # bound, i.e. '-til-<min_age>', so we iterate over time_windows[1:].
  for min_age in hparams.time_windows[1:]:
    seq_features.append(
        tf.feature_column.categorical_column_with_hash_bucket(
            SEQUENCE_KEY_PREFIX + k + '-til-' +
            str(min_age), bucket_size))
    seq_features_sizes.append(bucket_size)

categorical_context_features = [
    tf.feature_column.categorical_column_with_hash_bucket(
        CONTEXT_KEY_PREFIX + k, bucket_size)
    for k, bucket_size in zip(hparams.categorical_context_features,
                              hparams.context_bucket_sizes)
]
discretized_context_features = []
if hparams.include_age:
  discretized_context_features.append(
      tf.feature_column.bucketized_column(
          tf.feature_column.numeric_column(CONTEXT_KEY_PREFIX + AGE_KEY),
          boundaries=hparams.age_boundaries))

optimizer = tf.train.FtrlOptimizer(
    learning_rate=hparams.learning_rate,
    l1_regularization_strength=hparams.l1_regularization_strength,
    l2_regularization_strength=hparams.l2_regularization_strength)

estimator = tf.estimator.LinearClassifier(
    feature_columns=seq_features + categorical_context_features +
    discretized_context_features,
    n_classes=len(LABEL_VALUES),
    label_vocabulary=LABEL_VALUES,
    model_dir=output_dir,
    optimizer=optimizer,
    loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)

Setup additional metrics


In [ ]:
def multiclass_metrics_fn(labels, predictions):
  """Computes precsion/recall@k metrics for each class and micro-weighted.

  Args:
    labels: A string Tensor of shape [batch_size] with the true labels
    predictions: A float Tensor of shape [batch_size, num_classes].

  Returns:
    A dictionary with metrics of precision/recall @1/2 and precision/recall per
    class.
  """

  label_ids = tf.contrib.lookup.index_table_from_tensor(
      tuple(LABEL_VALUES),
      name='class_id_lookup').lookup(labels)
  dense_labels = tf.one_hot(label_ids, len(LABEL_VALUES))

  # We also convert the task to a binary one: a length of stay of at most 7 days
  # ('less_or_equal_3', '3_7') vs. more than 7 days ('7_14', 'above_14').
  binary_labels = label_ids < 2
  binary_probs = tf.reduce_sum(predictions['probabilities'][:, 0:2], axis=1)

  metrics_dict = {
      'precision_at_1':
          tf.metrics.precision_at_k(
              labels=label_ids,
              predictions=predictions['probabilities'], k=1),
      'precision_at_2':
          tf.metrics.precision_at_k(
              labels=label_ids,
              predictions=predictions['probabilities'], k=2),
      'recall_at_1':
          tf.metrics.recall_at_k(
              labels=label_ids,
              predictions=predictions['probabilities'], k=1),
      'recall_at_2':
          tf.metrics.recall_at_k(
              labels=label_ids,
              predictions=predictions['probabilities'], k=2),
      'auc_roc_at_most_7d':
          tf.metrics.auc(
              labels=binary_labels,
              predictions=binary_probs,
              curve='ROC',
              summation_method='careful_interpolation'),
      'auc_pr_at_most_7d':
          tf.metrics.auc(
              labels=binary_labels,
              predictions=binary_probs,
              curve='PR',
              summation_method='careful_interpolation'),
      'precision_at_most_7d':
          tf.metrics.precision(
              labels=binary_labels,
              predictions=binary_probs >= 0.5),
      'recall_at_most_7d':
          tf.metrics.recall(
              labels=binary_labels,
              predictions=binary_probs >= 0.5),
  }
  for i, label in enumerate(LABEL_VALUES):
    metrics_dict['precision_%s' % label] = tf.metrics.precision_at_k(
        labels=label_ids,
        predictions=predictions['probabilities'],
        k=1,
        class_id=i)
    metrics_dict['recall_%s' % label] = tf.metrics.recall_at_k(
        labels=label_ids,
        predictions=predictions['probabilities'],
        k=1,
        class_id=i)

  return metrics_dict
estimator = tf.estimator.add_metrics(estimator, multiclass_metrics_fn)

Train and Evaluate Estimator


In [ ]:
train_input_fn = get_input_fn(tf.estimator.ModeKeys.TRAIN, train_file, True, hparams.time_windows,
                            hparams.include_age, hparams.categorical_context_features,
                            hparams.sequence_features, time_crossed_features, batch_size=24)
validation_input_fn = get_input_fn(tf.estimator.ModeKeys.EVAL, validation_file, True, hparams.time_windows,
                            hparams.include_age, hparams.categorical_context_features,
                            hparams.sequence_features, time_crossed_features, batch_size=24)

In [ ]:
estimator.train(input_fn=train_input_fn, steps=100)
estimator.evaluate(input_fn=validation_input_fn, steps=40)
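
As a final, optional step (a minimal sketch assuming the estimator above has been trained), predictions can be generated with a PREDICT-mode input function; the prediction dicts produced by tf.estimator.LinearClassifier include the predicted class under 'classes' and the class probabilities under 'probabilities'.


In [ ]:
predict_input_fn = get_input_fn(tf.estimator.ModeKeys.PREDICT, validation_file, True, hparams.time_windows,
                            hparams.include_age, hparams.categorical_context_features,
                            hparams.sequence_features, time_crossed_features, batch_size=24,
                            shuffle=False)
# Print the predicted class and probabilities for the first few examples.
for i, prediction in enumerate(estimator.predict(input_fn=predict_input_fn)):
    print(prediction['classes'], prediction['probabilities'])
    if i >= 4:
        break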