The Context Tree Switching algorithm (CTS; Veness et al., 2012) is a powerful tool for sequential prediction. While deep learning methods often achieve better accuracy in the limit of data, CTS can be trained in a single pass, needs no hyperparameter tuning, and runs significantly faster than most current deep learning algorithms. It is also easily ensembled with other models of the same kind (see, e.g., Partition Tree Weighting). This notebook and the next one (on density modelling) are designed to show basic CTS usage and highlight the algorithm's statistical efficiency.
Both Joel Veness's implementation and the original SkipCTS implementation are designed with binary compression in mind. In binary compression, each symbol is either 0 or 1. By contrast, most of modern sequential prediction uses larger alphabets: pixel values, words, ASCII characters, etc. If you're theory-leaning, you may argue that larger alphabets can always be reduced to bits. If you're more practically-minded, you know this isn't a computationally efficient choice.
This Python implementation addresses this issue by providing a model which can deal with arbitrary (categorical) alphabets, while simultaneously making the algorithm more accessible. I chose to focus here on usage, rather than delve into the math underlying the CTS magic. You'll find equations to your heart's content by reading Context Tree Switching by Veness, Ng, Hutter, and Bowling (2012).
In [1]:
import math
import string
from cts import model
Although CTS truly shines when used for binary compression, it also does an excellent job at predicting text. We will begin by training our model on Alice in Wonderland by Lewis Carroll. Part of the Canterbury Corpus, Alice in Wonderland is one of many texts which have been aggressively compressed by generations of graduate students wanting to make their mark in the field of data compression.
In [2]:
# We will predict by looking back up to 8 characters into the past.
alice_model = model.ContextualSequenceModel(context_length=8)

def train_model_alice(alice_model):
    """This method trains a given model on Alice in Wonderland."""
    print ('Training model, this should take about 10-15 seconds...')
    with open('alice.txt') as fp:
        num_characters = 0
        alice_log_probability = 0
        while True:
            character = fp.read(1)
            if not character:
                break
            else:
                num_characters += 1
                # model.update() trains the model on one symbol at a time. It returns the log probability of the symbol.
                # The negative log probability is the model's loss on this element of the sequence.
                symbol_log_probability = alice_model.update(character)
                alice_log_probability += symbol_log_probability

    print ('Bytes read: {}'.format(num_characters))
    print ('Equivalent compressed size in bytes: {:.1f}'.format(-alice_log_probability / math.log(2) / 8))

train_model_alice(alice_model)
The equivalent compressed size above is the negative log probability (expressed in bytes) of the whole sequence. The really cool thing about compression is that arithmetic coding guarantees we can compress the original file down to within a couple of bits of this size (modulo headers). In this case we shrink the file to about a third of its original size.
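To make the unit conversion concrete, here is a minimal sketch of the arithmetic used in the print statement above. It assumes, as the division by math.log(2) suggests, that update() returns natural-log probabilities; the helper name nats_to_bytes and the toy sequence are made up for illustration.

```python
import math

def nats_to_bytes(log_probability):
    """Converts a natural-log probability into an ideal code length in bytes."""
    # Dividing by log(2) turns nats into bits; dividing by 8 turns bits into bytes.
    bits = -log_probability / math.log(2)
    return bits / 8

# A hypothetical sequence of 1000 symbols, each predicted with probability 1/4,
# costs 2 bits per symbol, i.e. 2000 bits or 250 bytes in total.
example_log_probability = 1000 * math.log(0.25)
print('Equivalent compressed size in bytes: {:.1f}'.format(nats_to_bytes(example_log_probability)))
```

The better the model predicts each symbol, the closer each log probability is to zero and the smaller the total.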
In machine learning terms, the sequence negative log probability is also the loss being minimized. Minimizing this loss is equivalent to maximizing prediction accuracy (for categorical distributions).
Context Tree Switching uses context (here, up to the eight most recent characters) to predict the next symbol. To see that this helps minimize the loss, consider the same model with a maximum context length of 0. This corresponds to predicting according to the empirical frequency of characters:
In [3]:
train_model_alice(model.ContextualSequenceModel(context_length=0))
Comparing the equivalent compressed sizes (94102 vs. 52911 bytes), it's clear that context improves prediction. Amazingly enough, the length 0 model still gets us a 43% compression rate. In this case, the model does something akin to adaptive Huffman coding, effectively rearranging the English alphabet for efficiency.
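For intuition about what a context length of 0 means, the sketch below implements a plain adaptive frequency model and measures its code length on a toy string. This is not the library's code: the function name, the fixed alphabet, and the additive pseudo-count estimator are all assumptions made for illustration.

```python
import math

def order0_code_length_bytes(text, alphabet, pseudo_count):
    """Ideal code length (in bytes) of `text` under an adaptive frequency model."""
    counts = {symbol: 0 for symbol in alphabet}
    total = 0
    log_probability = 0.0
    for symbol in text:
        # Predict first, using only the counts gathered so far...
        probability = (counts[symbol] + pseudo_count) / (total + pseudo_count * len(alphabet))
        log_probability += math.log(probability)
        # ...then update the counts with the symbol we just saw.
        counts[symbol] += 1
        total += 1
    return -log_probability / math.log(2) / 8

# Toy data (not Alice in Wonderland): a short, highly repetitive string.
toy_text = 'abracadabra' * 100
toy_alphabet = set(toy_text)
size = order0_code_length_bytes(toy_text, toy_alphabet, pseudo_count=1.0 / len(toy_alphabet))
print('Bytes read: {}'.format(len(toy_text)))
print('Equivalent compressed size in bytes: {:.1f}'.format(size))
```

Even without any context, frequent symbols such as 'a' get short codes, which is where the savings come from.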
George R.R. Martin, of Game of Thrones fame, is well-known for releasing new books ever so slowly. We used to joke around the lab that rather than wait for him to write the next book in the series (which usually takes five or six years), we would be better off training a CTS model to produce it for us. (As of November 2016, we're still waiting on that sixth book.)
In [4]:
def sample_sequence(model, num_symbols, rejection_sampling=True):
    """Samples a sequence from the model.

    Args:
        model: The sequence model to sample from. Its parameters are left unchanged.
        num_symbols: Sequence length.
        rejection_sampling: If True, only draw from observed symbols (see below).
    """
    old_context = model.context.copy()
    sampled_string = ''

    for t in range(num_symbols):
        # Sample a single symbol from the distribution. In this case we use rejection sampling to ignore characters
        # we haven't seen (this leads to much nicer text).
        symbol = model.sample(rejection_sampling=rejection_sampling)
        # observe() moves the model's history one symbol forward, without updating the model parameters.
        model.observe(symbol)
        sampled_string += symbol

    # Restore the model's context.
    model.context = old_context
    return sampled_string
Of course, the sample you're about to see was generated on the fly. It may be benign, genius, or downright offensive. I take no responsibility for the model's output.
In [5]:
print(sample_sequence(alice_model, num_symbols=400))
In generating the sequence above we cheated a little: we used rejection sampling to ignore any character which hasn't been observed in the particular context chosen by CTS. Restricting ourselves to observed symbols is a common trick used by sequential models (typically implicitly) to improve the qualitative look of their samples. In machine learning terms, we're overfitting to the observed alphabet.
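As a rough picture of what rejection sampling does here, the sketch below draws from a toy categorical distribution and simply redraws whenever the symbol falls outside a set of observed symbols. It is not how the library implements model.sample(); the function name, the toy distribution, and the set of observed symbols are hypothetical.

```python
import random

def sample_symbol(probabilities, observed=None, rng=random):
    """Draws one symbol from a categorical distribution.

    If `observed` is given, keep redrawing until the symbol belongs to it.
    `observed` must contain at least one symbol with non-zero probability.
    """
    symbols = sorted(probabilities)
    weights = [probabilities[s] for s in symbols]
    while True:
        symbol = rng.choices(symbols, weights=weights)[0]
        if observed is None or symbol in observed:
            return symbol

# Toy predictive distribution: two symbols seen in this context, plus two
# unseen symbols that only receive mass from the prior.
toy_distribution = {'a': 0.5, 'b': 0.3, '#': 0.1, '~': 0.1}
seen_in_context = {'a', 'b'}

rng = random.Random(0)
print(''.join(sample_symbol(toy_distribution, seen_in_context, rng) for _ in range(20)))
print(''.join(sample_symbol(toy_distribution, None, rng) for _ in range(20)))
```

The first line only ever contains 'a' and 'b'; the second can also contain the unseen symbols, which is what makes unrestricted samples look noisier.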
With CTS, as long as we define the alphabet ahead of time, we can sample from the true posterior. From an online learning perspective, this is an "honest" sample: it matches the loss we would suffer from predicting this symbol.
See for yourself how sample quality degrades.
In [6]:
# Since Alice in Wonderland is written in English, we can use string.printable as our alphabet.
printable_alphabet = set(string.printable)
# Let's make sure there aren't any non-printable characters in Alice in Wonderland.
assert(alice_model.model.alphabet.issubset(printable_alphabet))
# Create a model with a pre-specified alphabet.
core_model = model.CTS(alphabet=printable_alphabet, context_length=8)
# The ContextualSequenceModel is really a wrapper around CTS. It keeps track of the history of observed
# symbols and uses the most recent as context.
alice_model_small_alphabet = model.ContextualSequenceModel(model=core_model)
train_model_alice(alice_model_small_alphabet)
print ('Sampling a sequence:')
print (sample_sequence(alice_model_small_alphabet, num_symbols=400, rejection_sampling=False))
Notice also how, when we specify the alphabet, we get better predictive performance: we compress the file to 50866 bytes, rather than 52911 (even though the samples look worse). To see how significant this 4% improvement is, consider that there's good money (in pre-deep learning era dollars) for compressing 100MB of Wikipedia by one more percent. A four percent improvement on the Canterbury or Calgary corpora is probably enough for a best paper award at the Data Compression Conference.
In the spirit of maximizing prediction accuracy, our implementation uses the Perks prior by default. The Perks prior assigns a pseudo-count of $1/A$ to each symbol (in each context), where $A$ is the size of the alphabet. We can change this, for example to the more natural Laplace prior, named after Pierre-Simon Laplace's rule of succession. The Laplace prior uses a much larger pseudo-count of 1 per symbol.
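To see how much the pseudo-count matters, here is a small worked example. It assumes the usual additive-smoothing form of the estimator, (count + pseudo-count) / (total + A * pseudo-count); the exact form used inside the library may differ, and the function name is made up for illustration.

```python
def predictive_probability(count, total, alphabet_size, pseudo_count):
    """Additive-smoothing estimate of a symbol's probability in one context."""
    return (count + pseudo_count) / (total + alphabet_size * pseudo_count)

# Suppose a context has been seen 3 times, always followed by the same symbol,
# and the alphabet has 100 symbols (the size of string.printable).
alphabet_size = 100
perks = predictive_probability(3, 3, alphabet_size, pseudo_count=1.0 / alphabet_size)
laplace = predictive_probability(3, 3, alphabet_size, pseudo_count=1.0)
print('Perks:   {:.4f}'.format(perks))
print('Laplace: {:.4f}'.format(laplace))
```

Under these assumptions the Perks prior already assigns about 0.75 to the repeated symbol, while the Laplace prior assigns about 0.04, because 100 pseudo-counts swamp the 3 real observations.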
In [7]:
# Create a model with a pre-specified alphabet, this time using the Laplace prior.
core_model = model.CTS(alphabet=printable_alphabet, context_length=8, symbol_prior='laplace')
# The ContextualSequenceModel is really a wrapper around CTS. It keeps track of the history of observed
# symbols and uses the most recent as context.
alice_model_laplace_prior = model.ContextualSequenceModel(model=core_model)
train_model_alice(alice_model_laplace_prior)
Wowza! That's pretty bad. Better not get your priors wrong.
We conclude this section with a little fun experiment: sampling from the CTS prior over sequences. This shows one of the greatest strengths of the CTS algorithm: it adapts very quickly to new data. We'll sample a sequence as before, except now we properly update the model with each sampled symbol.
Oh, and also, we won't train the model ahead of time.
In [8]:
# Let's use the letters subset to make the whole thing more legible.
sample_alphabet = set(string.ascii_lowercase + ' ')

def sample_from_prior():
    core_model = model.CTS(alphabet=sample_alphabet, context_length=8, symbol_prior='perks')
    prior_model = model.ContextualSequenceModel(model=core_model)
    sampled_string = ''
    for t in range(400):
        symbol = prior_model.sample(rejection_sampling=False)
        sampled_string += symbol
        # Note: we call update() to adjust the model parameters to the sample.
        prior_model.update(symbol)
    print (sampled_string)

for i in range(4):
    print ('Sequence {}:'.format(i))
    sample_from_prior()
    print ()
'What am I seeing,' you ask? These are the kind of sequences that CTS expects to observe before it sees any data. That is, these are typical sequences under the CTS prior, given we may only observe lowercase ASCII letters and space. Re-run the above cell a number of times and you'll see similar results. The nice structure heavily depends on the Perks prior, which you can verify by changing symbol_prior to 'laplace' or 'jeffreys'. You can also set the prior to a floating-point value (the pseudo-count for each symbol; try 0.1 / len(printable_alphabet)).
Sampling from the prior illustrates why the Perks prior does so much better at compressing text than Jeffreys's or Laplace's priors: the samples look more text-like (modulo alphabet permutations). Still, there's a lot of room for improvement! The divergence of these samples from true English (or French, or...) text is exactly the excess cost we suffer for compressing text with CTS.
Isn't that cool?
The next tutorial demonstrates how to use CTS for modelling images.