Copyright 2014 by Andreas Klostermann andreas.klostermann@gmail.com
PyMC 3 is a revolutionary framework for implementing Bayesian inference methods. While it is not yet as feature-complete as PyMC 2 or other similar software, it already offers advanced sampling methods.
Markov Chain Monte Carlo (MCMC) methods generate a large set of samples, from which further inference is derived. This sampling often takes the form of a random walk across the parameter space. Different sampling methods perform this random walk in very different ways, and the example below is deliberately constructed to be hard to sample from.
Depending on the model, sampling can take quite a while, so it is desirable to watch the sampler's progress and its preliminary results as they arrive.
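To make the random walk concrete, here is a minimal sketch of a random-walk Metropolis sampler written with the standard library only. It is not PyMC's implementation. It assumes the target is the banana-shaped posterior sampled later in this notebook (standard normal priors on x and y, and 200 observations equal to zero with mean x + y ** 2), and it assumes the third argument of Normal there acts as a unit scale. The names `log_posterior`, `random_walk_metropolis`, `step_size` and `seed` are made up for this illustration.

```python
import math
import random


def log_posterior(x, y, n_obs=200):
    """Unnormalized log density of the assumed banana-shaped target."""
    log_prior = -0.5 * (x * x + y * y)
    # n_obs observations equal to zero, each with mean x + y**2.
    log_likelihood = -0.5 * n_obs * (x + y * y) ** 2
    return log_prior + log_likelihood


def random_walk_metropolis(n_draws, step_size=0.1, seed=0):
    """Return (samples, acceptance_rate) for a Gaussian random walk."""
    rng = random.Random(seed)
    x, y = 0.0, 0.0
    current = log_posterior(x, y)
    samples = []
    accepted = 0
    for _ in range(n_draws):
        # Propose a small Gaussian step around the current position.
        x_new = x + rng.gauss(0.0, step_size)
        y_new = y + rng.gauss(0.0, step_size)
        proposed = log_posterior(x_new, y_new)
        # Accept with probability min(1, p(new) / p(current)).
        if math.log(1.0 - rng.random()) < proposed - current:
            x, y, current = x_new, y_new, proposed
            accepted += 1
        # A rejected proposal repeats the current position.
        samples.append((x, y))
    return samples, accepted / n_draws


samples, rate = random_walk_metropolis(2000)
```

Because the density is concentrated along the thin curve x = -y ** 2, most proposals that step away from the curve are rejected, which is what makes this target awkward for a plain random walk.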
The code in this notebook is in a very experimental state. I used Bokeh's Git revision "dab61e9" from September 13th, and IPython 2.1.0 with an Anaconda distribution. PyMC 3 is revision "1f2a639", installed locally.
From experience I expect Bokeh's API to change rapidly, and I may or may not have the time to update these notebooks.
In [1]:
from __future__ import division
import json
import time
import numpy as np
import pandas as pd
from IPython.core import display
from IPython.html import widgets
from IPython.utils.traitlets import Unicode
import bokeh
from bokeh.plotting import *
from bokeh.protocol import serialize_json
from pymc import *
from pymc import backends
In [2]:
output_notebook()
This widget does not have a "stop" button, because that would require some more PyMC hacking. In addition to the "replace_bokeh_data_source" message, there is also an "append_to_bokeh_data_source" message, which updates a data source by appending new data. In MCMC sampling, old samples never change or get deleted, so it is faster to transmit only the new data.
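The difference between the two message types can be sketched in plain Python. This is only an illustration of the idea, not the widget code: the function names and the shortened "replace" and "append" labels are made up here, and `json.dumps` stands in for Bokeh's serializer.

```python
import json


def replace_message(columns):
    """Build a message that carries every column in full."""
    return {"custom_type": "replace", "ds_json": json.dumps(columns)}


def append_message(columns, n):
    """Build a message that carries only the last n entries per column."""
    tail = {}
    for name, values in columns.items():
        # Slicing from an explicit start index also handles n == 0,
        # where values[-n:] would return the whole column.
        start = max(len(values) - n, 0)
        tail[name] = values[start:]
    return {"custom_type": "append", "ds_json": json.dumps(tail)}


def apply_append(client_columns, message):
    """Mimic the browser side: push the new entries onto each column."""
    for name, values in json.loads(message["ds_json"]).items():
        client_columns[name].extend(values)


# The "server" holds 1000 samples, the "client" has seen the first 990.
server = {"x": list(range(1000)), "y": [0.5 * i for i in range(1000)]}
client = {"x": list(range(990)), "y": [0.5 * i for i in range(990)]}

message = append_message(server, 10)
apply_append(client, message)

append_size = len(message["ds_json"])
replace_size = len(replace_message(server)["ds_json"])
```

After `apply_append`, the client columns match the server columns, while `append_size` is much smaller than `replace_size`. The gap grows with the length of the chain.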
In [3]:
%%javascript
require(["widgets/js/widget"], function(WidgetManager){
    var AnimationWidget = IPython.WidgetView.extend({
        render: function(){
            // Display as simple h1 tag
            var html = $("<h1>"+this.model.get("title")+"</h1>");
            this.html = html;
            this.setElement(html);
            this.model.on('msg:custom', $.proxy(this.handle_custom_message, this));
        },
        handle_custom_message: function(data){
            switch (data.custom_type){
                case "replace_bokeh_data_source":
                    var ds = Bokeh.Collections(data.ds_model).get(data.ds_id);
                    ds.set($.parseJSON(data.ds_json));
                    ds.trigger("change");
                    break;
                case "append_to_bokeh_data_source":
                    var ds = Bokeh.Collections(data.ds_model).get(data.ds_id);
                    var changes = $.parseJSON(data.ds_json);
                    for (var col in changes) {
                        if (changes.hasOwnProperty(col)) {
                            var c = changes[col];
                            for (var i=0; i<c.length; i++)
                                ds.attributes.data[col].push(c[i]);
                        }
                    }
                    ds.trigger("change");
                    break;
            }
        }
    });
    WidgetManager.register_widget_view('AnimationWidget', AnimationWidget);
});
The Python side of the widget implements the two message types as methods. There is an additional twist, though: I want to store the final result in the notebook file, so that it is available to nbconvert. To achieve this, the widget keeps track of the latest state of each data source, and once the simulation has finished, the widget's "finish" method inlines that state into the output as script tags.
This allows the Bokeh library to render the final result in nbconvert's HTML output.
I haven't looked into other output formats yet, though.
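The bookkeeping behind "finish" can be sketched on its own. The following is a simplified illustration, not the widget itself: `FinalStateTracker` and its methods are hypothetical names, `json.dumps` stands in for Bokeh's serializer, and the JavaScript in the template mirrors the calls used in this notebook. Note the doubled braces, which `str.format` turns into literal braces.

```python
import json

# Doubled braces survive str.format as single literal braces.
SCRIPT_TEMPLATE = """<script type="text/javascript">
$(function() {{
    var ds = Bokeh.Collections('{ds_model}').get('{ds_id}');
    ds.set({ds_json});
    ds.trigger("change");
}});
</script>"""


class FinalStateTracker(object):
    """Remember only the most recent state of each data source."""

    def __init__(self):
        self.latest = {}

    def update(self, ds_id, ds_model, state):
        # A later update for the same id overwrites the earlier one.
        self.latest[ds_id] = (ds_model, state)

    def render(self):
        """Return one script tag per tracked data source."""
        tags = []
        for ds_id in sorted(self.latest):
            ds_model, state = self.latest[ds_id]
            tags.append(SCRIPT_TEMPLATE.format(
                ds_model=ds_model, ds_id=ds_id, ds_json=json.dumps(state)))
        return tags


tracker = FinalStateTracker()
tracker.update("abc", "ColumnDataSource", {"data": {"x": [0]}})
tracker.update("abc", "ColumnDataSource", {"data": {"x": [0, 1, 2]}})
tags = tracker.render()
```

Here `tags` holds a single script tag containing the last state of source "abc". One caveat of inlining JSON this way: the literal text `</script>` inside the data would end the tag early, so a real implementation would need to escape it.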
In [4]:
from IPython.utils.traitlets import Unicode
import json


class AnimationWidget(widgets.DOMWidget):
    _view_name = Unicode('AnimationWidget', sync=True)
    title = Unicode(sync=True)

    def __init__(self, title, *args, **kwargs):
        widgets.DOMWidget.__init__(self, *args, **kwargs)
        self.iteration = 0
        self.on_msg(self._handle_custom_msg)
        self.stopped = False
        self.title = title
        self.send_state()
        # Latest known state of each data source, keyed by model id
        self.final_replacements = {}

    def replace_bokeh_data_source(self, ds):
        # Named ds_json so that the json module is not shadowed
        ds_json = serialize_json(ds.vm_serialize())
        self.send({"custom_type": "replace_bokeh_data_source",
                   "ds_id": ds.ref['id'],       # Model Id
                   "ds_model": ds.ref['type'],  # Collection Type
                   "ds_json": ds_json
                   })
        self.final_replacements[ds.ref['id']] = dict(ds_model=ds.ref['type'],
                                                     ds_id=ds.ref['id'],
                                                     ds_json=ds_json)

    def append_to_bokeh_data_source(self, ds, columns, n):
        # Only the last n entries of each column are transmitted
        changes = {}
        for col in columns:
            changes[col] = ds.data[col][-n:]
        self.send({"custom_type": "append_to_bokeh_data_source",
                   "ds_id": ds.ref['id'],       # Model Id
                   "ds_model": ds.ref['type'],  # Collection Type
                   "ds_json": serialize_json(changes)
                   })
        # Keep a reference to the data source; it is serialized in finish()
        self.final_replacements[ds.ref['id']] = dict(ds_model=ds.ref['type'],
                                                     ds_id=ds.ref['id'],
                                                     ds_json=None,
                                                     ds=ds)

    def finish(self):
        for replacement in self.final_replacements.values():
            ds_model = replacement['ds_model']
            ds_id = replacement['ds_id']
            ds_json = replacement['ds_json']
            if ds_json is None:
                ds_json = serialize_json(replacement['ds'].vm_serialize())
            html = """<script type="text/javascript">
            $(function() {{
                var ds = Bokeh.Collections('{ds_model}').get('{ds_id}');
                var data = {ds_json};
                ds.set(data);
                ds.trigger("change");
            }});
            </script>
            """.format(ds_model=ds_model, ds_id=ds_id, ds_json=ds_json)
            display.display_html(html, raw=True)
PyMC has a mechanism which allows users to subclass the sample storage backend. The default backend is "NDArray", which just stores the samples in a NumPy array. There is another one which stores the samples in an SQLite database.
The Visualizer class subclasses the NDArray backend. The idea is to modify the graphs when new data is stored.
In [5]:
class Visualizer(backends.NDArray):
    def __init__(self, name=None, model=None, vars=None, widget=None):
        super(Visualizer, self).__init__(name, model, vars)
        self.widget = widget

    def setup(self, draws, chain):
        self.last_time = time.time()
        self.draws = draws
        super(Visualizer, self).setup(draws, chain)
        self.show()

    def record(self, point):
        super(Visualizer, self).record(point)
        get_ipython().kernel.do_one_iteration()
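The hook pattern used by the Visualizer can be tried without PyMC. The sketch below makes loud assumptions: `ListBackend` is a hypothetical stand-in for a storage backend, `ObservedBackend` and `on_record` are names invented for this illustration, and none of them belong to PyMC's API. It only shows the shape of the idea: store the sample first, then notify whoever wants to draw it.

```python
class ListBackend(object):
    """Hypothetical storage backend that keeps samples in a list."""

    def __init__(self):
        self.samples = []
        self.draw_idx = 0

    def setup(self, draws):
        # Reset the storage before a new sampling run.
        self.samples = []
        self.draw_idx = 0

    def record(self, point):
        self.samples.append(dict(point))
        self.draw_idx += 1


class ObservedBackend(ListBackend):
    """Backend that calls a hook after every stored sample."""

    def __init__(self, on_record):
        super(ObservedBackend, self).__init__()
        self.on_record = on_record

    def record(self, point):
        # Store first, so the hook always sees up-to-date storage.
        super(ObservedBackend, self).record(point)
        self.on_record(self.draw_idx, point)


seen = []
backend = ObservedBackend(lambda idx, point: seen.append((idx, point["x"])))
backend.setup(draws=3)
for value in (0.1, 0.2, 0.3):
    backend.record({"x": value})
```

After the loop, `seen` lists each draw index together with the recorded value, which is the information a plotting routine would need to update a graph.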
The BananaVisualizer subclass creates plots for the specific model we will sample from, and updates the data sources of these plots. I am using the new gridplot function, and I am pretty sure its API is going to change very soon!
In [6]:
class BananaVisualizer(Visualizer):
    def show(self):
        height = 300
        width_a = 500
        width_b = 300
        self.last_published_index = 0

        # Trace plot of x
        figure(plot_width=width_a, plot_height=height,
               x_range=bokeh.plotting_helpers.Range1d(start=0, end=self.draws),
               y_range=bokeh.plotting_helpers.Range1d(start=-2, end=2),
               )
        plot_x = line([0], [0], title="X",)
        data_sources = [r.data_source for r in plot_x.renderers if hasattr(r, 'data_source')]
        self.ds_x = data_sources[0]

        # Trace plot of y
        figure(plot_width=width_a, plot_height=height,
               x_range=bokeh.plotting_helpers.Range1d(start=0, end=self.draws),
               y_range=bokeh.plotting_helpers.Range1d(start=-2, end=2),
               )
        plot_y = line([0], [0], title="Y",)
        data_sources = [r.data_source for r in plot_y.renderers if hasattr(r, 'data_source')]
        self.ds_y = data_sources[0]

        # Scatter plot of x against y
        figure(plot_width=width_b, plot_height=height)
        plot_xb = scatter([0], [0], title="X vs Y", alpha=0.1, line_width=1,
                          x_range=bokeh.plotting_helpers.Range1d(start=-2, end=2),
                          y_range=bokeh.plotting_helpers.Range1d(start=-2, end=2)
                          )
        data_sources = [r.data_source for r in plot_xb.renderers if hasattr(r, 'data_source')]
        self.ds_scatter = data_sources[0]

        figure(plot_width=width_b, plot_height=height)
        # Histograms in Bokeh can be implemented using quads
        plot_yb = quad(top=[0], bottom=0, left=[0], right=[0],
                       fill_color="#036564", line_color="#033649",
                       alpha=0.5, title="Histogram")
        # To draw more than one glyph in one plot, we need to "hold"
        hold(True)
        plot_yb = quad(top=[0], bottom=0, left=[0], right=[0],
                       fill_color="#146564", line_color="#033649", alpha=0.5)
        # Because of holding the plot, we now have two data sources
        data_sources = [r.data_source for r in plot_yb.renderers if hasattr(r, 'data_source')]
        # I haven't found a guarantee in Bokeh's documentation about the order of these data sources,
        # but I suspect that currently, the order is predictable.
        self.ds_hist_x = data_sources[0]
        self.ds_hist_y = data_sources[1]

        gridplot([[plot_x, plot_xb], [plot_y, plot_yb]])
        show()

    def record(self, point):
        super(BananaVisualizer, self).record(point)
        duration = time.time() - self.last_time
        # Publish at most once every 0.05 seconds
        if duration >= 0.05:
            # Number of draws recorded since the last published update
            elapsed_index = self.draw_idx - self.last_published_index
            self.last_published_index = self.draw_idx
            self.last_time = time.time()

            # Trace and scatter plots only need the new samples appended
            self.ds_x.data['x'] = np.arange(0, self.draw_idx)
            self.ds_x.data['y'] = self.samples['x'][:self.draw_idx]
            self.widget.append_to_bokeh_data_source(self.ds_x, ['x', 'y'], elapsed_index)

            self.ds_y.data['x'] = np.arange(0, self.draw_idx)
            self.ds_y.data['y'] = self.samples['y'][:self.draw_idx]
            self.widget.append_to_bokeh_data_source(self.ds_y, ['x', 'y'], elapsed_index)

            self.ds_scatter.data['x'] = self.samples['x'][:self.draw_idx]
            self.ds_scatter.data['y'] = self.samples['y'][:self.draw_idx]
            self.widget.append_to_bokeh_data_source(self.ds_scatter, ['x', 'y'], elapsed_index)

            # Histogram bins change with every update, so they are replaced
            hist, edges = np.histogram(self.samples['x'][:self.draw_idx], bins=20)
            self.ds_hist_x.data['top'] = hist
            self.ds_hist_x.data['left'] = edges[:-1]
            self.ds_hist_x.data['right'] = edges[1:]
            self.widget.replace_bokeh_data_source(self.ds_hist_x)

            hist, edges = np.histogram(self.samples['y'][:self.draw_idx], bins=20)
            self.ds_hist_y.data['top'] = hist
            self.ds_hist_y.data['left'] = edges[:-1]
            self.ds_hist_y.data['right'] = edges[1:]
            self.widget.replace_bokeh_data_source(self.ds_hist_y)
The record method also throttles the updates to a reasonable pace. This throttling could be more elaborate; for example, it could take into account how much time the sampling process spends inside the visualization routines. It would not be useful to have a visualization that takes two seconds to compute and serialize the data sources when the sampler could otherwise draw a thousand samples per second.
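One way such a more elaborate throttle might look is sketched below. It is not part of the notebook's code: `AdaptiveThrottle`, `min_interval`, `max_fraction` and the injectable `clock` are assumptions made for this illustration. The idea is to measure how long the last update took and to wait long enough that updates consume at most a fixed share of the total time.

```python
import time


class AdaptiveThrottle(object):
    """Permit updates only while they stay within a time budget.

    min_interval is the minimum pause between two updates in seconds.
    max_fraction is the largest share of wall-clock time that updates
    may consume.
    """

    def __init__(self, min_interval=0.05, max_fraction=0.2, clock=time.time):
        self.min_interval = min_interval
        self.max_fraction = max_fraction
        self.clock = clock
        self.last_update_end = clock()
        self.last_cost = 0.0

    def ready(self):
        # If an update costs c seconds and we then sample for w seconds,
        # updates take c / (c + w) of the time. Keeping that share at or
        # below max_fraction requires w >= c * (1 / max_fraction - 1).
        budget_wait = self.last_cost * (1.0 / self.max_fraction - 1.0)
        wait = max(self.min_interval, budget_wait)
        return self.clock() - self.last_update_end >= wait

    def run(self, update):
        """Call update() and remember how long it took."""
        started = self.clock()
        update()
        self.last_update_end = self.clock()
        self.last_cost = self.last_update_end - started
```

Inside a record method this could be used as `if throttle.ready(): throttle.run(publish)`, where `publish` is whatever function pushes the data sources to the browser. With `max_fraction=0.2`, an update that takes two seconds would be followed by at least eight seconds of uninterrupted sampling.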
In [7]:
n = 6000
with Model() as model:
    x = Normal('x', 0, 1)
    y = Normal('y', 0, 1)
    N = 200
    d = Normal('d', x + y ** 2, 1., observed=np.zeros(N))

    w = AnimationWidget("Metropolis-Hastings Sampler")
    display.display(w)
    backend = BananaVisualizer(widget=w)
    step = Metropolis()
    trace = sample(n, step, trace=backend, progressbar=False)
w.finish()
In [8]:
n = 6000
with Model() as model:
    x = Normal('x', 0, 1)
    y = Normal('y', 0, 1)
    N = 200
    d = Normal('d', x + y ** 2, 1., observed=np.zeros(N))

    w = AnimationWidget("Slice Sampler")
    display.display(w)
    backend = BananaVisualizer(widget=w)
    step = Slice()
    trace = sample(n, step, trace=backend, progressbar=False)
w.finish()
In [9]:
n = 6000
with Model() as model:
    x = Normal('x', 0, 1)
    y = Normal('y', 0, 1)
    N = 200
    d = Normal('d', x + y ** 2, 1., observed=np.zeros(N))

    w = AnimationWidget("Hamiltonian MC")
    display.display(w)
    backend = BananaVisualizer(widget=w)
    start = model.test_point
    h = np.ones(2) * np.diag(find_hessian(start))[0]
    step = HamiltonianMC(model.vars, h, path_length=4.)
    trace = sample(n, step, start, trace=backend, progressbar=False)
w.finish()
There are various useful ways to visualize MCMC methods. In some cases, autocorrelation plots, or convergence diagnostics in general, would be very useful. It might also be worthwhile to create a more sophisticated "model runner" that can visualize arbitrary models and provides a start/stop button and a progress bar.
For models that have some geographic aspect, you could use Bokeh's map functionality.
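As an example of the diagnostics mentioned above, the values behind an autocorrelation plot can be computed with a few lines of plain Python. This is a minimal sketch and not something this notebook implements. The function name `autocorrelation` is made up here, and for long chains a NumPy or FFT-based version would be faster.

```python
def autocorrelation(values, lag):
    """Sample autocorrelation of a sequence at a non-negative lag."""
    n = len(values)
    if not 0 <= lag < n:
        raise ValueError("lag must satisfy 0 <= lag < len(values)")
    mean = sum(values) / float(n)
    centered = [v - mean for v in values]
    variance = sum(c * c for c in centered)
    if variance == 0:
        raise ValueError("autocorrelation is undefined for a constant sequence")
    # Products of values that lie `lag` steps apart.
    covariance = sum(centered[i] * centered[i + lag] for i in range(n - lag))
    return covariance / variance


# A slowly drifting sequence is strongly correlated with its own past.
drifting = list(range(100))
# A sequence that flips sign every step is strongly anti-correlated.
alternating = [1.0 if i % 2 == 0 else -1.0 for i in range(100)]

drifting_lag1 = autocorrelation(drifting, 1)
alternating_lag1 = autocorrelation(alternating, 1)
```

Applied to a trace, values that stay close to 1 over many lags indicate a chain that moves slowly through the parameter space, so the effective number of independent samples is much smaller than the number of draws.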