In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
In [0]:
import collections
import tensorflow as tf
tf.compat.v2.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
There are three important concepts associated with TensorFlow Distributions shapes:
- Event shape describes the shape of a single draw from the distribution; it may be dependent across dimensions. For scalar distributions, the event shape is []. For a 5-dimensional MultivariateNormal, the event shape is [5].
- Batch shape describes independent, not identically distributed draws, aka a "batch" of distributions.
- Sample shape describes independent, identically distributed draws of batches from the distribution family.
The event shape and the batch shape are properties of a Distribution object, whereas the sample shape is associated with a specific call to sample or log_prob.
This notebook's purpose is to illustrate these concepts through examples, so if this isn't immediately obvious, don't worry!
For another conceptual overview of these concepts, see this blog post.
This entire notebook is written using TensorFlow Eager. None of the concepts presented rely on Eager, although with Eager, distribution batch and event shapes are evaluated (and therefore known) when the Distribution
object is created in Python, whereas in graph (non-Eager mode), it is possible to define distributions whose event and batch shapes are undetermined until the graph is run.
In [0]:
def describe_distributions(distributions):
  print('\n'.join([str(d) for d in distributions]))
In this section we'll explore scalar distributions: distributions with an event shape of []. A typical example is the Poisson distribution, specified by a rate:
In [5]:
poisson_distributions = [
    tfd.Poisson(rate=1., name='One Poisson Scalar Batch'),
    tfd.Poisson(rate=[1., 10., 100.], name='Three Poissons'),
    tfd.Poisson(rate=[[1., 10., 100.], [2., 20., 200.]],
                name='Two-by-Three Poissons'),
    tfd.Poisson(rate=[1.], name='One Poisson Vector Batch'),
    tfd.Poisson(rate=[[1.]], name='One Poisson Expanded Batch')
]
describe_distributions(poisson_distributions)
The Poisson distribution is a scalar distribution, so its event shape is always []. If we specify more rates, these show up in the batch shape. The final pair of examples is interesting: there's only a single rate, but because that rate is embedded in a numpy array with non-empty shape, that shape becomes the batch shape.
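To make this concrete, we can inspect the shape properties directly (a quick check of our own, reusing the list defined above):
In [0]:
# Batch and event shapes are properties of the Distribution objects:
for d in poisson_distributions:
  print(d.name, 'batch_shape:', d.batch_shape, 'event_shape:', d.event_shape)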
The standard Normal distribution is also a scalar distribution. Its event shape is [], just like for the Poisson, but we'll play with it to see our first example of broadcasting. The Normal is specified using loc and scale parameters:
In [6]:
normal_distributions = [
    tfd.Normal(loc=0., scale=1., name='Standard'),
    tfd.Normal(loc=[0.], scale=1., name='Standard Vector Batch'),
    tfd.Normal(loc=[0., 1., 2., 3.], scale=1., name='Different Locs'),
    tfd.Normal(loc=[0., 1., 2., 3.], scale=[[1.], [5.]],
               name='Broadcasting Scale')
]
describe_distributions(normal_distributions)
The interesting example above is the Broadcasting Scale distribution. The loc parameter has shape [4], and the scale parameter has shape [2, 1]. Using Numpy broadcasting rules, the batch shape is [2, 4]. An equivalent (but less elegant and not-recommended) way to define the "Broadcasting Scale" distribution would be:
In [7]:
describe_distributions(
    [tfd.Normal(loc=[[0., 1., 2., 3.], [0., 1., 2., 3.]],
                scale=[[1., 1., 1., 1.], [5., 5., 5., 5.]])])
We can see why the broadcasting notation is useful, although it's also a source of headaches and bugs.
There are two main things we can do with distributions: we can sample from them and we can compute log_probs. Let's explore sampling first. The basic rule is that when we sample from a distribution, the resulting Tensor has shape [sample_shape, batch_shape, event_shape], where batch_shape and event_shape are provided by the Distribution object, and sample_shape is provided by the call to sample. For scalar distributions, event_shape = [], so the Tensor returned from sample will have shape [sample_shape, batch_shape]. Let's try it:
In [8]:
def describe_sample_tensor_shape(sample_shape, distribution):
  print('Sample shape:', sample_shape)
  print('Returned sample tensor shape:',
        distribution.sample(sample_shape).shape)

def describe_sample_tensor_shapes(distributions, sample_shapes):
  for distribution in distributions:
    print(distribution)
    for sample_shape in sample_shapes:
      describe_sample_tensor_shape(sample_shape, distribution)
    print()
sample_shapes = [1, 2, [1, 5], [3, 4, 5]]
describe_sample_tensor_shapes(poisson_distributions, sample_shapes)
In [9]:
describe_sample_tensor_shapes(normal_distributions, sample_shapes)
That's about all there is to say about sample: returned sample tensors have shape [sample_shape, batch_shape, event_shape].
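As a quick check of our own (reusing poisson_distributions from above; the list index is ours):
In [0]:
# The two-by-three batch from above: sample_shape [3, 5] + batch_shape [2, 3]
# + event_shape [] gives a [3, 5, 2, 3] tensor.
samples = poisson_distributions[2].sample([3, 5])
print(samples.shape)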
log_prob For Scalar Distributions
Now let's take a look at log_prob, which is somewhat trickier. log_prob takes as input a (non-empty) tensor representing the location(s) at which to compute the log_prob for the distribution. In the most straightforward case, this tensor will have a shape of the form [sample_shape, batch_shape, event_shape], where batch_shape and event_shape match the batch and event shapes of the distribution. Recall once more that for scalar distributions, event_shape = [], so the input tensor has shape [sample_shape, batch_shape]. In this case, we get back a tensor of shape [sample_shape, batch_shape]:
In [10]:
three_poissons = tfd.Poisson(rate=[1., 10., 100.], name='Three Poissons')
three_poissons
Out[10]:
In [11]:
three_poissons.log_prob([[1., 10., 100.], [100., 10., 1]]) # sample_shape is [2].
Out[11]:
In [12]:
three_poissons.log_prob([[[[1., 10., 100.], [100., 10., 1.]]]]) # sample_shape is [1, 1, 2].
Out[12]:
Note how in the first example, the input and output have shape [2, 3] and in the second example they have shape [1, 1, 2, 3].
That would be all there was to say, if it weren't for broadcasting. Here are the rules once we take broadcasting into account. We describe it in full generality and note simplifications for scalar distributions:
1. Define n = len(batch_shape) + len(event_shape). (For scalar distributions, len(event_shape) = 0.)
2. If the input tensor t has fewer than n dimensions, pad its shape by adding dimensions of size 1 on the left until it has exactly n dimensions. Call the resulting tensor t'.
3. Broadcast the n rightmost dimensions of t' against the [batch_shape, event_shape] of the distribution you're computing a log_prob for. In more detail: for the dimensions where t' already matches the distribution, do nothing, and for the dimensions where t' has a singleton, replicate that singleton the appropriate number of times. Any other situation is an error. (For scalar distributions, we only broadcast against batch_shape, since event_shape = [].)
4. Compute the log_prob. The resulting tensor will have shape [sample_shape, batch_shape], where sample_shape is defined to be any dimensions of t or t' to the left of the n rightmost dimensions: sample_shape = shape(t)[:-n].
This might be a mess if you don't know what it means, so let's work some examples:
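First, as a sanity check, here is a small helper of our own (a sketch applying the rules above, not part of TFP) that predicts the shape log_prob will return:
In [0]:
def predicted_log_prob_shape(distribution, input_shape):
  # Rule 1: n is the number of rightmost dimensions that must line up
  # with the distribution's [batch_shape, event_shape].
  dist_shape = list(distribution.batch_shape) + list(distribution.event_shape)
  n = len(dist_shape)
  # Rule 2: left-pad the input shape with 1s until it has n dimensions.
  padded = [1] * max(0, n - len(input_shape)) + list(input_shape)
  # Rule 3: broadcast the n rightmost dimensions against the distribution.
  for input_dim, dist_dim in zip(padded[len(padded) - n:], dist_shape):
    if input_dim not in (1, dist_dim):
      raise ValueError('Shapes are not broadcast-compatible.')
  # Rule 4: the result is [sample_shape, batch_shape].
  return padded[:len(padded) - n] + list(distribution.batch_shape)
For example, predicted_log_prob_shape(three_poissons, [2, 2, 1]) gives [2, 2, 3], matching the broadcasting example worked below.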
In [13]:
three_poissons.log_prob([10.])
Out[13]:
The tensor [10.] (with shape [1]) is broadcast across the batch_shape of 3, so we evaluate all three Poissons' log probability at the value 10.
In [14]:
three_poissons.log_prob([[[1.], [10.]], [[100.], [1000.]]])
Out[14]:
In the above example, the input tensor has shape [2, 2, 1], while the distribution object has a batch shape of 3. So for each of the [2, 2] sample dimensions, the single value provided gets broadcast to each of the three Poissons.
A possibly useful way to think of it: because three_poissons has batch_shape = [3], a call to log_prob must take a Tensor whose last dimension is either 1 or 3; anything else is an error. (The numpy broadcasting rules treat the special case of a scalar as being totally equivalent to a Tensor of shape [1].)
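For instance (our own illustrative check, not a cell from the original), a last dimension of 2 is rejected:
In [0]:
# Last dimension is 2, which is neither 1 nor 3, so broadcasting fails:
try:
  three_poissons.log_prob([1., 10.])
except Exception as e:
  print('Error:', type(e).__name__)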
Let's test our chops by playing with the more complex Poisson distribution with batch_shape = [2, 3]:
In [0]:
poisson_2_by_3 = tfd.Poisson(
    rate=[[1., 10., 100.], [2., 20., 200.]],
    name='Two-by-Three Poissons')
In [16]:
poisson_2_by_3.log_prob(1.)
Out[16]:
In [17]:
poisson_2_by_3.log_prob([1.]) # Exactly equivalent to above, demonstrating the scalar special case.
Out[17]:
In [18]:
poisson_2_by_3.log_prob([[1., 1., 1.], [1., 1., 1.]]) # Another way to write the same thing. No broadcasting.
Out[18]:
In [19]:
poisson_2_by_3.log_prob([[1., 10., 100.]]) # Input is [1, 3] broadcast to [2, 3].
Out[19]:
In [20]:
poisson_2_by_3.log_prob([[1., 10., 100.], [1., 10., 100.]]) # Equivalent to above. No broadcasting.
Out[20]:
In [21]:
poisson_2_by_3.log_prob([[1., 1., 1.], [2., 2., 2.]]) # No broadcasting.
Out[21]:
In [22]:
poisson_2_by_3.log_prob([[1.], [2.]]) # Equivalent to above. Input shape [2, 1] broadcast to [2, 3].
Out[22]:
The above examples involved broadcasting over the batch, but the sample shape was empty. Suppose we have a collection of values, and we want to get the log probability of each value at each point in the batch. We could do it manually:
In [23]:
poisson_2_by_3.log_prob([[[1., 1., 1.], [1., 1., 1.]], [[2., 2., 2.], [2., 2., 2.]]]) # Input shape [2, 2, 3].
Out[23]:
Or we could let broadcasting handle the last batch dimension:
In [24]:
poisson_2_by_3.log_prob([[[1.], [1.]], [[2.], [2.]]]) # Input shape [2, 2, 1].
Out[24]:
We can also (perhaps somewhat less naturally) let broadcasting handle just the first batch dimension:
In [25]:
poisson_2_by_3.log_prob([[[1., 1., 1.]], [[2., 2., 2.]]]) # Input shape [2, 1, 3].
Out[25]:
Or we could let broadcasting handle both batch dimensions:
In [26]:
poisson_2_by_3.log_prob([[[1.]], [[2.]]]) # Input shape [2, 1, 1].
Out[26]:
The above worked fine when we had only two values we wanted, but suppose we had a long list of values we wanted to evaluate at every batch point. For that, the following notation, which adds extra dimensions of size 1 to the right side of the shape, is extremely useful:
In [27]:
poisson_2_by_3.log_prob(tf.constant([1., 2.])[..., tf.newaxis, tf.newaxis])
Out[27]:
This is an instance of strided slice notation, which is worth knowing.
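To see what the indexing does on its own (a quick illustration we've added):
In [0]:
# [..., tf.newaxis, tf.newaxis] appends two dimensions of size 1:
values = tf.constant([1., 2.])
print(values[..., tf.newaxis, tf.newaxis].shape)  # ==> (2, 1, 1)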
Going back to three_poissons for completeness, the same example looks like:
In [28]:
three_poissons.log_prob([[1.], [10.], [50.], [100.]])
Out[28]:
In [29]:
three_poissons.log_prob(tf.constant([1., 10., 50., 100.])[..., tf.newaxis]) # Equivalent to above.
Out[29]:
In [30]:
multinomial_distributions = [
    # Multinomial is a vector-valued distribution: if we have k classes,
    # an individual sample from the distribution has k values in it, so the
    # event_shape is `[k]`.
    tfd.Multinomial(total_count=100., probs=[.5, .4, .1],
                    name='One Multinomial'),
    tfd.Multinomial(total_count=[100., 1000.], probs=[.5, .4, .1],
                    name='Two Multinomials Same Probs'),
    tfd.Multinomial(total_count=100., probs=[[.5, .4, .1], [.1, .2, .7]],
                    name='Two Multinomials Same Counts'),
    tfd.Multinomial(total_count=[100., 1000.],
                    probs=[[.5, .4, .1], [.1, .2, .7]],
                    name='Two Multinomials Different Everything')
]
describe_distributions(multinomial_distributions)
Note how in the last three examples, the batch_shape is always [2], but we can use broadcasting to either have a shared total_count or a shared probs (or neither), because under the hood they are broadcast to have the same shape.
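For instance (a quick check of our own; the list index is ours):
In [0]:
# total_count has shape [2] and probs has shape [3]; after broadcasting,
# batch_shape is [2] and event_shape is [3].
d = multinomial_distributions[1]  # 'Two Multinomials Same Probs'
print(d.batch_shape, d.event_shape)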
Sampling is straightforward, given what we know already:
In [31]:
describe_sample_tensor_shapes(multinomial_distributions, sample_shapes)
Computing log probabilities is equally straightforward. Let's work an example with diagonal Multivariate Normal distributions. (Multinomials are not very broadcast-friendly, since the constraints on the counts and probabilities mean broadcasting will often produce inadmissible values.) We'll use a batch of two 3-dimensional distributions with the same mean but different scales (standard deviations):
In [32]:
two_multivariate_normals = tfd.MultivariateNormalDiag(loc=[1., 2., 3.], scale_identity_multiplier=[1., 2.])
two_multivariate_normals
Out[32]:
(Note that although we used distributions where the scales were multiples of the identity, this is not a restriction; we could pass scale_diag instead of scale_identity_multiplier.)
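For example (our sketch of the equivalent construction with explicit per-dimension scales):
In [0]:
# The same batch of two 3-dimensional normals, with the scales spelled out:
tfd.MultivariateNormalDiag(
    loc=[1., 2., 3.],
    scale_diag=[[1., 1., 1.], [2., 2., 2.]])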
Now let's evaluate the log probability of each batch point at its mean and at a shifted mean:
In [33]:
two_multivariate_normals.log_prob([[[1., 2., 3.]], [[3., 4., 5.]]]) # Input has shape [2,1,3].
Out[33]:
Exactly equivalently, we can use strided slice notation (https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/strided-slice) to insert an extra dimension of size 1 in the middle of a constant:
In [34]:
two_multivariate_normals.log_prob(
    tf.constant([[1., 2., 3.], [3., 4., 5.]])[:, tf.newaxis, :])  # Equivalent to above.
Out[34]:
On the other hand, if we don't insert the extra dimension, we pass [1., 2., 3.] to the first batch point and [3., 4., 5.] to the second:
In [35]:
two_multivariate_normals.log_prob(tf.constant([[1., 2., 3.], [3., 4., 5.]]))
Out[35]:
In [36]:
six_way_multinomial = tfd.Multinomial(total_count=1000., probs=[.3, .25, .2, .15, .08, .02])
six_way_multinomial
Out[36]:
We created a multinomial with an event shape of [6]. The Reshape Bijector allows us to treat this as a distribution with an event shape of [2, 3].
A Bijector represents a differentiable, one-to-one function on an open subset of ${\mathbb R}^n$. Bijectors are used in conjunction with TransformedDistribution, which models a distribution $p(y)$ in terms of a base distribution $p(x)$ and a Bijector that represents $Y = g(X)$.
Let's see it in action:
In [37]:
transformed_multinomial = tfd.TransformedDistribution(
    distribution=six_way_multinomial,
    bijector=tfb.Reshape(event_shape_out=[2, 3]))
transformed_multinomial
Out[37]:
In [38]:
six_way_multinomial.log_prob([500., 100., 100., 150., 100., 50.])
Out[38]:
In [39]:
transformed_multinomial.log_prob([[500., 100., 100.], [150., 100., 50.]])
Out[39]:
This is the only thing the Reshape bijector can do: it cannot turn event dimensions into batch dimensions or vice-versa.
The Independent distribution is used to treat a collection of independent, not-necessarily-identical (aka a batch of) distributions as a single distribution. More concisely, Independent allows us to convert dimensions in batch_shape to dimensions in event_shape. We'll illustrate by example:
In [40]:
two_by_five_bernoulli = tfd.Bernoulli(
    probs=[[.05, .1, .15, .2, .25], [.3, .35, .4, .45, .5]],
    name="Two By Five Bernoulli")
two_by_five_bernoulli
Out[40]:
We can think of this as a two-by-five array of coins with the associated probabilities of heads. Let's evaluate the probability of a particular, arbitrary set of ones and zeros:
In [41]:
pattern = [[1., 0., 0., 1., 0.], [0., 0., 1., 1., 1.]]
two_by_five_bernoulli.log_prob(pattern)
Out[41]:
We can use Independent to turn this into two different "sets of five Bernoulli's", which is useful if we want to consider a "row" of coin flips coming up in a given pattern as a single outcome:
In [42]:
two_sets_of_five = tfd.Independent(
    distribution=two_by_five_bernoulli,
    reinterpreted_batch_ndims=1,
    name="Two Sets Of Five")
two_sets_of_five
Out[42]:
Mathematically, we're computing the log probability of each "set" of five by summing the log probabilities of the five "independent" coin flips in the set, which is where the distribution gets its name:
In [43]:
two_sets_of_five.log_prob(pattern)
Out[43]:
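We can verify the summing relationship directly (our own quick check using tf.reduce_sum):
In [0]:
# Summing each row of the per-coin log_probs reproduces the
# two_sets_of_five.log_prob values above:
print(tf.reduce_sum(two_by_five_bernoulli.log_prob(pattern), axis=-1))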
We can go even further and use Independent to create a distribution where individual events are a set of two-by-five Bernoulli's:
In [44]:
one_set_of_two_by_five = tfd.Independent(
    distribution=two_by_five_bernoulli, reinterpreted_batch_ndims=2,
    name="One Set Of Two By Five")
one_set_of_two_by_five.log_prob(pattern)
Out[44]:
It's worth noting that from the perspective of sample, using Independent changes nothing:
In [45]:
describe_sample_tensor_shapes(
    [two_by_five_bernoulli,
     two_sets_of_five,
     one_set_of_two_by_five],
    [[3, 5]])
As a parting exercise for the reader, we suggest considering the differences and similarities between a vector batch of Normal distributions and a MultivariateNormalDiag distribution from a sampling and log probability perspective. How can we use Independent to construct a MultivariateNormalDiag from a batch of Normals? (Note that MultivariateNormalDiag is not actually implemented this way.)
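One way to check your answer (our sketch, and, as the note above says, not how MultivariateNormalDiag is actually implemented):
In [0]:
# Reinterpret a [3]-batch of Normals as one 3-dimensional distribution:
batch_of_normals = tfd.Normal(loc=[0., 1., 2.], scale=[1., 1., 1.])
mvn_like = tfd.Independent(batch_of_normals, reinterpreted_batch_ndims=1)
mvn = tfd.MultivariateNormalDiag(loc=[0., 1., 2.], scale_diag=[1., 1., 1.])
x = [0.5, 1.5, 2.5]
print(mvn_like.log_prob(x), mvn.log_prob(x))  # The log_probs agree.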