In [ ]:
import tensorflow as tf
from google.protobuf import text_format
import os
# Either set the path to the directory with the generated tf.SequenceExamples split into train and validation sets
# or set it to a directory in which to store synthetically generated examples.
path = '/tmp/'
output_dir = '/tmp/models/'
In [ ]:
sequence_examples_pbtxt = """
context {
feature: {
key : "Patient.gender"
value: {
bytes_list: {
value: [ "male" ]
}
}
}
feature: {
key : "label.length_of_stay_range.class"
value: {
bytes_list: {
value: [ "above_14" ]
}
}
}
feature: {
key : "Patient.birthDate"
value: {
int64_list: {
value: [ -1167580800 ]
}
}
}
feature: {
key : "timestamp"
value: {
int64_list: {
value: [ 1528917657 ]
}
}
}
feature: {
key : "sequenceLength"
value: {
int64_list: {
value: [ 4 ]
}
}
}
}
feature_lists {
feature_list {
key: "Observation.code"
value {
feature {
bytes_list {
value: "loinc:2"
}
}
feature {
bytes_list {
value: "loinc:4"
}
}
feature {
bytes_list {
value: "loinc:6"
}
}
feature {
bytes_list {
value: "loinc:6"
}
}
}
}
feature_list {
key: "Observation.value.quantity.value"
value {
feature {
float_list {
value: 1.0
}
}
feature {
float_list {
value: 2.0
}
}
feature {
float_list {
}
}
feature {
float_list {
}
}
}
}
feature_list {
key: "Observation.value.quantity.unit"
value {
feature {
bytes_list {
value: "mg/L"
}
}
feature {
bytes_list {
}
}
feature {
bytes_list {
}
}
feature {
bytes_list {
}
}
}
}
feature_list {
key: "eventId"
value {
feature {
int64_list {
value: 1528917644
}
}
feature {
int64_list {
value: 1528917645
}
}
feature {
int64_list {
value: 1528917646
}
}
feature {
int64_list {
value: 1528917647
}
}
}
}
}
"""
In [ ]:
seqex_pb = tf.train.SequenceExample()
text_format.Parse(sequence_examples_pbtxt, seqex_pb)
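In [ ]:
# Optional check (a minimal sketch): in this synthetic example the
# 'sequenceLength' context feature should equal the number of steps in each
# feature_list, which the model code later relies on.
parsed_length = seqex_pb.context.feature['sequenceLength'].int64_list.value[0]
num_events = len(seqex_pb.feature_lists.feature_list['eventId'].feature)
print('sequenceLength:', parsed_length, '| eventId steps:', num_events)
assert parsed_length == num_events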
In [ ]:
# Write the example into training and validation files.
for split in ['train', 'validation']:
with tf.io.TFRecordWriter(path + split) as writer:
for seqex in [seqex_pb] * 100:
writer.write(seqex.SerializeToString())
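In [ ]:
# Optional sanity check (a sketch): read the first record back from the training
# file and confirm it round-trips to the SequenceExample we serialized above.
first_record = next(tf.python_io.tf_record_iterator(path + 'train'))
assert tf.train.SequenceExample.FromString(first_record) == seqex_pb
print('Wrote', sum(1 for _ in tf.python_io.tf_record_iterator(path + 'train')),
      'records to', path + 'train')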
In [ ]:
def create_hparams(hparams_overrides=None):
"""Creates default HParams with the option of overrides.
Args:
hparams_overrides: HParams overriding the otherwise provided defaults.
      Defaults to None (meaning no overrides take place). Any HParams specified
      need to reference a subset of the defaults.
  Returns:
    The default HParams, with any overrides applied.
"""
hparams = tf.contrib.training.HParams(
      # Sequence features are bucketed by their age at time of prediction into
      # the intervals (time_windows[1], time_windows[0]],
      # (time_windows[2], time_windows[1]],
      # ...
time_windows=[
5 * 365 * 24 * 60 * 60, # 5 years
365 * 24 * 60 * 60, # 1 year
30 * 24 * 60 * 60, # 1 month
7 * 24 * 60 * 60, # 1 week
1 * 24 * 60 * 60, # 1 day
0, # now
],
batch_size=64,
learning_rate=0.003,
dedup=True,
l1_regularization_strength=0.0,
l2_regularization_strength=0.0,
include_age=True,
age_boundaries=[1, 5, 18, 30, 50, 70, 90],
categorical_context_features=['Patient.gender'],
sequence_features=[
'Composition.section.text.div.tokenized',
'Composition.type',
'Condition.code',
'Encounter.hospitalization.admitSource',
'Encounter.reason.hcc',
'MedicationRequest.contained.medication.code.gsn',
'Procedure.code.cpt',
],
# Number of hash buckets to map the tokens of the sequence_features into.
sequence_bucket_sizes=[
17000,
16,
3052,
10,
62,
1600,
732,
],
      # List of strings, each of which is a ':'-separated list of features that
      # we want to cross at each step along the time dimension.
time_crossed_features=[
'%s:%s:%s:%s' % ('Observation.code',
'Observation.value.quantity.value',
'Observation.value.quantity.unit',
'Observation.value.string')
],
time_concat_bucket_sizes=[39571],
context_bucket_sizes=[4])
# Other overrides (possibly coming from vizier) are applied.
if hparams_overrides:
    hparams = hparams.override_from_dict(hparams_overrides.values())
return hparams
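In [ ]:
# Example usage of the overrides (a sketch): pass an HParams object whose fields
# are a subset of the defaults, as the docstring requires. The values here are
# illustrative only.
hparams_override_demo = create_hparams(
    tf.contrib.training.HParams(batch_size=32, learning_rate=0.01))
print(hparams_override_demo.batch_size, hparams_override_demo.learning_rate)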
In [ ]:
CONTEXT_KEY_PREFIX = 'c-'
SEQUENCE_KEY_PREFIX = 's-'
AGE_KEY = 'Patient.ageInYears'
LABEL_VALUES = ['less_or_equal_3', '3_7', '7_14', 'above_14']
def _example_index_to_sparse_index(example_indices, batch_size):
"""Creates a sparse index tensor from a list of example indices.
For example, this would do the transformation:
[0, 0, 0, 1, 3, 3] -> [[0,0], [0,1], [0,2], [1,0], [3,0], [3,1]]
The second column of the output tensor is the running count of the occurrences
of that example index.
Args:
example_indices: A sorted 1D Tensor with example indices.
batch_size: The batch_size. Could be larger than max(example_indices) if the
last examples of the batch do not have the feature present.
Returns:
The sparse index tensor.
    The maximum length of a row in this tensor.
"""
binned_counts = tf.bincount(example_indices, minlength=batch_size)
max_len = tf.to_int64(tf.reduce_max(binned_counts))
return tf.where(tf.sequence_mask(binned_counts)), max_len
def _dedup_tensor(sp_tensor):
"""Dedup values of a SparseTensor along each row.
Args:
sp_tensor: A 2D SparseTensor to be deduped.
Returns:
A deduped SparseTensor of shape [batch_size, max_len], where max_len is
the maximum number of unique values for a row in the Tensor.
"""
string_batch_index = tf.as_string(sp_tensor.indices[:, 0])
# tf.unique only works on 1D tensors. To avoid deduping across examples,
# prepend each feature value with the example index. This requires casting
# to and from strings for non-string features.
original_dtype = sp_tensor.values.dtype
string_values = (
sp_tensor.values
if original_dtype == tf.string else tf.as_string(sp_tensor.values))
index_and_value = tf.string_join([string_batch_index, string_values],
separator='|')
unique_index_and_value, _ = tf.unique(index_and_value)
# split is a shape [tf.size(values), 2] tensor. The first column contains
# indices and the second column contains the feature value (we assume no
# feature contains | so we get exactly 2 values from the string split).
split = tf.string_split(unique_index_and_value, delimiter='|')
split = tf.reshape(split.values, [-1, 2])
string_indices = split[:, 0]
values = split[:, 1]
indices = tf.reshape(
tf.string_to_number(string_indices, out_type=tf.int32), [-1])
if original_dtype != tf.string:
values = tf.string_to_number(values, out_type=original_dtype)
values = tf.reshape(values, [-1])
# Convert example indices into SparseTensor indices, e.g.
# [0, 0, 0, 1, 3, 3] -> [[0,0], [0,1], [0,2], [1,0], [3,0], [3,1]]
batch_size = tf.to_int32(sp_tensor.dense_shape[0])
new_indices, max_len = _example_index_to_sparse_index(indices, batch_size)
return tf.SparseTensor(
indices=tf.to_int64(new_indices),
values=values,
dense_shape=[tf.to_int64(batch_size), max_len])
def get_input_fn(mode,
input_pattern,
dedup,
time_windows,
include_age,
categorical_context_features,
sequence_features,
time_crossed_features,
batch_size,
shuffle=True):
"""Creates an input function to an estimator.
Args:
mode: The execution mode, as defined in tf.estimator.ModeKeys.
input_pattern: Input data pattern in TFRecord format containing
tf.SequenceExamples.
dedup: Whether to remove duplicate values.
    time_windows: List of time windows (in seconds, decreasing) - we bucket all
      sequence features by their age into buckets (time_windows[i+1],
      time_windows[i]].
include_age: Whether to include the age_in_years as a feature.
categorical_context_features: List of string context features that are valid
keys in the tf.SequenceExample.
sequence_features: List of sequence features (strings) that are valid keys
in the tf.SequenceExample.
    time_crossed_features: List of lists of sequence features (strings) that
should be crossed at each step along the time dimension.
batch_size: The size of the batch when reading in data.
shuffle: Whether to shuffle the examples.
Returns:
A function that returns a dictionary of features and the target labels.
"""
def input_fn():
"""Supplies input to our model.
This function supplies input to our model, where this input is a
function of the mode. For example, we supply different data if
we're performing training versus evaluation.
Returns:
A tuple consisting of 1) a dictionary of tensors whose keys are
the feature names, and 2) a tensor of target labels if the mode
      is not PREDICT (and None, otherwise).
"""
sequence_features_config = dict()
for feature in sequence_features:
dtype = tf.string
if feature == 'Observation.value.quantity.value':
dtype = tf.float32
sequence_features_config[feature] = tf.VarLenFeature(dtype)
sequence_features_config['eventId'] = tf.FixedLenSequenceFeature(
[], tf.int64, allow_missing=False)
for cross in time_crossed_features:
for feature in cross:
dtype = tf.string
if feature == 'Observation.value.quantity.value':
dtype = tf.float32
sequence_features_config[feature] = tf.VarLenFeature(dtype)
context_features_config = dict()
if include_age:
context_features_config['timestamp'] = tf.FixedLenFeature(
[], tf.int64, default_value=-1)
context_features_config['Patient.birthDate'] = tf.FixedLenFeature(
[], tf.int64, default_value=-1)
context_features_config['sequenceLength'] = tf.FixedLenFeature(
[], tf.int64, default_value=-1)
for context_feature in categorical_context_features:
context_features_config[context_feature] = tf.VarLenFeature(tf.string)
if mode != tf.estimator.ModeKeys.PREDICT:
context_features_config['label.length_of_stay_range.class'] = (
tf.FixedLenFeature([], tf.string, default_value='MISSING'))
is_training = mode == tf.estimator.ModeKeys.TRAIN
num_epochs = None if is_training else 1
with tf.name_scope('read_batch'):
file_names = [input_pattern]
files = tf.data.Dataset.list_files(file_names)
if shuffle:
files = files.shuffle(buffer_size=len(file_names))
dataset = (files
                 .apply(tf.data.experimental.parallel_interleave(
tf.data.TFRecordDataset, cycle_length=10))
.repeat(num_epochs))
if shuffle:
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(batch_size)
def _parse_fn(serialized_examples):
context, sequence, _ = tf.io.parse_sequence_example(
serialized_examples,
context_features=context_features_config,
sequence_features=sequence_features_config,
name='parse_sequence_example')
return context, sequence
dataset = dataset.map(_parse_fn, num_parallel_calls=8)
def _process(context, sequence):
"""Supplies input to our model.
This function supplies input to our model after parsing.
Args:
context: The dictionary from key to (Sparse)Tensors with context
features
sequence: The dictionary from key to (Sparse)Tensors with sequence
features
Returns:
A tuple consisting of 1) a dictionary of tensors whose keys are
the feature names, and 2) a tensor of target labels if the mode
        is not PREDICT (and None, otherwise).
"""
# Combine into a single dictionary.
feature_map = {}
# Add age if requested.
if include_age:
age_in_seconds = (
context['timestamp'] -
context.pop('Patient.birthDate'))
age_in_years = tf.to_float(age_in_seconds) / (60 * 60 * 24 * 365.0)
feature_map[CONTEXT_KEY_PREFIX + AGE_KEY] = age_in_years
sequence_length = context.pop('sequenceLength')
# Cross the requested features.
for cross in time_crossed_features:
# The features may be missing at different rates - we take the union
# of the indices supplying defaults.
extended_features = dict()
dense_shape = tf.concat(
[[tf.shape(sequence_length)[0]], [tf.reduce_max(sequence_length)],
tf.constant([1], dtype=tf.int64)],
axis=0)
for i, feature in enumerate(cross):
sp_tensor = sequence[feature]
additional_indices = []
covered_indices = sp_tensor.indices
for j, other_feature in enumerate(cross):
if i != j:
additional_indices.append(
tf.sets.set_difference(
tf.sparse_reorder(
tf.SparseTensor(
indices=sequence[other_feature].indices,
values=tf.zeros([
tf.shape(sequence[other_feature].indices)[0]
],
dtype=tf.int32),
dense_shape=dense_shape)),
tf.sparse_reorder(
tf.SparseTensor(
indices=covered_indices,
values=tf.zeros([tf.shape(covered_indices)[0]],
dtype=tf.int32),
dense_shape=dense_shape))).indices)
covered_indices = tf.concat(
[sp_tensor.indices] + additional_indices, axis=0)
additional_indices = tf.concat(additional_indices, axis=0)
# Supply defaults for all other indices.
default = tf.tile(
tf.constant(['n/a']),
multiples=[tf.shape(additional_indices)[0]])
string_value = (
tf.as_string(sp_tensor.values)
if sp_tensor.values.dtype != tf.string else sp_tensor.values)
extended_features[feature] = tf.sparse_reorder(
tf.SparseTensor(
indices=tf.concat([sp_tensor.indices, additional_indices],
axis=0),
values=tf.concat([string_value, default], axis=0),
dense_shape=dense_shape))
new_values = tf.string_join(
[extended_features[f].values for f in cross], separator='-')
crossed_sp_tensor = tf.sparse_reorder(
tf.SparseTensor(
indices=extended_features[cross[0]].indices,
values=new_values,
dense_shape=extended_features[cross[0]].dense_shape))
sequence['_'.join(cross)] = crossed_sp_tensor
# Remove unwanted features that are used in the cross but should not be
# considered outside the cross.
for cross in time_crossed_features:
for feature in cross:
if feature not in sequence_features and feature in sequence:
del sequence[feature]
# Flatten sparse tensor to compute event age. This dense tensor also
# contains padded values. These will not be used when gathering elements
# from the dense tensor since each sparse feature won't have a value
# defined for the padding.
padded_event_age = (
# Broadcast current time along sequence dimension.
tf.expand_dims(context.pop('timestamp'), 1)
# Subtract time of events.
- sequence.pop('eventId'))
for i in range(len(time_windows) - 1):
max_age = time_windows[i]
min_age = time_windows[i+1]
padded_in_time_window = tf.logical_and(padded_event_age <= max_age,
padded_event_age > min_age)
for k, v in sequence.items():
# For each sparse feature entry, look up whether it is in the time
# window or not.
in_time_window = tf.gather_nd(padded_in_time_window,
v.indices[:, 0:2])
v = tf.sparse_retain(v, in_time_window)
sp_tensor = tf.sparse_reshape(v, [v.dense_shape[0], -1])
if dedup:
sp_tensor = _dedup_tensor(sp_tensor)
feature_map[SEQUENCE_KEY_PREFIX + k +
                      '-til-%d' % min_age] = sp_tensor
for k, v in context.items():
feature_map[CONTEXT_KEY_PREFIX + k] = v
return feature_map
feature_map = (dataset
# Parallelize the input processing and put it behind a
# queue to increase performance by removing it from the
# critical path of per-step-computation.
.map(_process, num_parallel_calls=8)
.prefetch(buffer_size=1)
.make_one_shot_iterator()
.get_next())
label = None
if mode != tf.estimator.ModeKeys.PREDICT:
label = feature_map.pop(CONTEXT_KEY_PREFIX +
'label.length_of_stay_range.class')
return feature_map, label
return input_fn
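In [ ]:
# Tiny illustration (a sketch, not part of the pipeline): run _dedup_tensor on a
# toy 2x3 SparseTensor to show that duplicate values within a row are removed.
demo_sp = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
    values=['a', 'a', 'b', 'c', 'c'],
    dense_shape=[2, 3])
with tf.Session() as demo_sess:
  # Expect one row containing 'a', 'b' and one row containing 'c' (plus '' padding).
  print(demo_sess.run(
      tf.sparse_tensor_to_dense(_dedup_tensor(demo_sp), default_value='')))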
In [ ]:
tf.reset_default_graph()
train_file = path + 'train'
validation_file = path + 'validation'
hparams = create_hparams()
time_crossed_features = [
cross.split(':') for cross in hparams.time_crossed_features if cross
]
map_, label_ = get_input_fn(tf.estimator.ModeKeys.TRAIN, train_file, True, hparams.time_windows,
hparams.include_age, hparams.categorical_context_features,
hparams.sequence_features, time_crossed_features, batch_size=2)()
with tf.train.MonitoredSession() as sess:
map_['label'] = label_
print(sess.run(map_))
In [ ]:
seq_features = []
seq_features_sizes = []
hparams = create_hparams()
for k, bucket_size in zip(
hparams.sequence_features,
hparams.sequence_bucket_sizes):
  # input_fn names each bucketed feature 's-<feature>-til-<min_age>', where
  # min_age ranges over time_windows[1:], so mirror those keys here.
  for min_age in hparams.time_windows[1:]:
    seq_features.append(
        tf.feature_column.categorical_column_with_hash_bucket(
            SEQUENCE_KEY_PREFIX + k + '-til-' +
            str(min_age), bucket_size))
seq_features_sizes.append(bucket_size)
categorical_context_features = [
tf.feature_column.categorical_column_with_hash_bucket(
CONTEXT_KEY_PREFIX + k, bucket_size)
for k, bucket_size in zip(hparams.categorical_context_features,
hparams.context_bucket_sizes)
]
discretized_context_features = []
if hparams.include_age:
discretized_context_features.append(
tf.feature_column.bucketized_column(
tf.feature_column.numeric_column(CONTEXT_KEY_PREFIX + AGE_KEY),
boundaries=hparams.age_boundaries))
optimizer = tf.train.FtrlOptimizer(
learning_rate=hparams.learning_rate,
l1_regularization_strength=hparams.l1_regularization_strength,
l2_regularization_strength=hparams.l2_regularization_strength)
estimator = tf.estimator.LinearClassifier(
feature_columns=seq_features + categorical_context_features +
discretized_context_features,
n_classes=len(LABEL_VALUES),
label_vocabulary=LABEL_VALUES,
model_dir=output_dir,
optimizer=optimizer,
loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)
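In [ ]:
# Quick look (a sketch): the hashed sequence columns must carry the same names as
# the 's-<feature>-til-<seconds>' keys emitted by input_fn; print a few to verify.
print(len(seq_features), 'sequence feature columns, for example:')
for column in seq_features[:3]:
  print(' ', column.name)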
In [ ]:
def multiclass_metrics_fn(labels, predictions):
"""Computes precsion/recall@k metrics for each class and micro-weighted.
Args:
labels: A string Tensor of shape [batch_size] with the true labels
predictions: A float Tensor of shape [batch_size, num_classes].
Returns:
A dictionary with metrics of precision/recall @1/2 and precision/recall per
class.
"""
label_ids = tf.contrib.lookup.index_table_from_tensor(
tuple(LABEL_VALUES),
name='class_id_lookup').lookup(labels)
dense_labels = tf.one_hot(label_ids, len(LABEL_VALUES))
  # We convert the task to a binary one of <= 7 days:
  # 'less_or_equal_3' and '3_7' vs. '7_14' and 'above_14'.
binary_labels = label_ids < 2
binary_probs = tf.reduce_sum(predictions['probabilities'][:, 0:2], axis=1)
metrics_dict = {
'precision_at_1':
tf.metrics.precision_at_k(
labels=label_ids,
predictions=predictions['probabilities'], k=1),
'precision_at_2':
tf.metrics.precision_at_k(
labels=label_ids,
predictions=predictions['probabilities'], k=2),
'recall_at_1':
tf.metrics.recall_at_k(
labels=label_ids,
predictions=predictions['probabilities'], k=1),
'recall_at_2':
tf.metrics.recall_at_k(
labels=label_ids,
predictions=predictions['probabilities'], k=2),
'auc_roc_at_most_7d':
tf.metrics.auc(
labels=binary_labels,
predictions=binary_probs,
curve='ROC',
summation_method='careful_interpolation'),
'auc_pr_at_most_7d':
tf.metrics.auc(
labels=binary_labels,
predictions=binary_probs,
curve='PR',
summation_method='careful_interpolation'),
'precision_at_most_7d':
tf.metrics.precision(
labels=binary_labels,
predictions=binary_probs >= 0.5),
'recall_at_most_7d':
tf.metrics.recall(
labels=binary_labels,
predictions=binary_probs >= 0.5),
}
for i, label in enumerate(LABEL_VALUES):
metrics_dict['precision_%s' % label] = tf.metrics.precision_at_k(
labels=label_ids,
predictions=predictions['probabilities'],
k=1,
class_id=i)
metrics_dict['recall_%s' % label] = tf.metrics.recall_at_k(
labels=label_ids,
predictions=predictions['probabilities'],
k=1,
class_id=i)
return metrics_dict
estimator = tf.estimator.add_metrics(estimator, multiclass_metrics_fn)
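In [ ]:
# Illustration (a sketch): the binary 'at most 7 days' metrics above group the
# first two classes together (label_ids < 2).
for i, value in enumerate(LABEL_VALUES):
  print(value, '-> at most 7 days:', i < 2)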
In [ ]:
train_input_fn = get_input_fn(tf.estimator.ModeKeys.TRAIN, train_file, True, hparams.time_windows,
hparams.include_age, hparams.categorical_context_features,
hparams.sequence_features, time_crossed_features, batch_size=24)
validation_input_fn = get_input_fn(tf.estimator.ModeKeys.EVAL, validation_file, True, hparams.time_windows,
hparams.include_age, hparams.categorical_context_features,
hparams.sequence_features, time_crossed_features, batch_size=24)
In [ ]:
estimator.train(input_fn=train_input_fn, steps=100)
estimator.evaluate(input_fn=validation_input_fn, steps=40)
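In [ ]:
# Optional follow-up (a sketch, assuming the checkpoint written by train() above):
# run inference on the validation file. In PREDICT mode the input_fn returns no
# labels, and predict() yields one dictionary per example.
predict_input_fn = get_input_fn(
    tf.estimator.ModeKeys.PREDICT, validation_file, True, hparams.time_windows,
    hparams.include_age, hparams.categorical_context_features,
    hparams.sequence_features, time_crossed_features, batch_size=24,
    shuffle=False)
for prediction in estimator.predict(input_fn=predict_input_fn):
  print(prediction['classes'], prediction['probabilities'])
  break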