In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
In this colab we show various examples of building learnable ("trainable") distributions. (We make no effort to explain the distributions, only to show how to build them.)
In [0]:
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static
tfb = tfp.bijectors
tfd = tfp.distributions
tf.enable_v2_behavior()
In [0]:
# Dimensionality of a single event; used as the trailing dimension of the
# `loc` parameters (and the size of the `tf.eye` scale matrices) built below.
event_size = 4
# Number of mixture components; used as the leading dimension of the
# per-component parameters and the length of the mixture logits built below.
num_components = 3
In [0]:
# Multivariate normal whose covariance is `scale**2 * I`: each coordinate has
# its own learnable mean, but all coordinates share ONE learnable scale.
learnable_mvn_scaled_identity = tfd.Independent(
    tfd.Normal(
        loc=tf.Variable(tf.zeros(event_size), name='loc'),
        scale=tfp.util.TransformedVariable(
            # Shape `[1]` broadcasts against `loc`'s `[event_size]`, so the
            # single scale is shared by every coordinate. (A shape of
            # `[event_size, 1]` would instead broadcast to
            # `[event_size, event_size]` and yield a *batch* of distributions.)
            tf.ones([1]),
            # `Exp` keeps the constrained scale strictly positive.
            bijector=tfb.Exp(),
            # The name belongs to the variable, not to the `Normal`.
            name='scale')),
    # Fold the trailing batch dimension into the event: event_shape=[event_size].
    reinterpreted_batch_ndims=1,
    name='learnable_mvn_scaled_identity')

print(learnable_mvn_scaled_identity)
print(learnable_mvn_scaled_identity.trainable_variables)
In [0]:
# Multivariate normal with diagonal covariance: each coordinate has its own
# learnable mean and its own learnable scale.
learnable_mvndiag = tfd.Independent(
    tfd.Normal(
        loc=tf.Variable(tf.zeros(event_size), name='loc'),
        scale=tfp.util.TransformedVariable(
            # One scale per coordinate (same shape as `loc`).
            tf.ones(event_size),
            bijector=tfb.Softplus(),  # Use Softplus...cuz why not?
            # The name belongs to the variable, not to the `Normal`.
            name='scale')),
    # Fold the trailing batch dimension into the event: event_shape=[event_size].
    reinterpreted_batch_ndims=1,
    name='learnable_mvn_diag')

print(learnable_mvndiag)
print(learnable_mvndiag.trainable_variables)
In [0]:
# Mixture of `num_components` scaled-identity multivariate normals. Every
# mixture logit is learnable; each component has its own mean vector and a
# single scale shared across its coordinates.
learnable_mix_mvn_scaled_identity = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        logits=tf.Variable(
            # With `1.` all logits start at zero (uniform mixture weights).
            # Changing the `1.` initializes with a geometric decay.
            -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32),
            name='logits')),
    components_distribution=tfd.Independent(
        tfd.Normal(
            loc=tf.Variable(
                tf.random.normal([num_components, event_size]),
                name='loc'),
            scale=tfp.util.TransformedVariable(
                # `[num_components, 1]` broadcasts against `loc`: one scale per
                # component, shared across that component's coordinates.
                10. * tf.ones([num_components, 1]),
                bijector=tfb.Softplus(),  # Use Softplus...cuz why not?
                # The name belongs to the variable, not to the `Normal`.
                name='scale')),
        reinterpreted_batch_ndims=1),
    name='learnable_mix_mvn_scaled_identity')

print(learnable_mix_mvn_scaled_identity)
print(learnable_mix_mvn_scaled_identity.trainable_variables)
In [0]:
# Mixture of `num_components` diagonal multivariate normals in which the first
# mixture logit is pinned to zero (removing the redundant degree of freedom in
# the logits); only the remaining `num_components - 1` logits are learnable.
learnable_mix_mvndiag_first_fixed = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        logits=tfp.util.TransformedVariable(
            # Initialize logits as geometric decay.
            -tf.math.log(1.5) * tf.range(num_components, dtype=tf.float32),
            # `Pad` prepends one constant zero, so the underlying variable
            # only stores the trailing `num_components - 1` logits.
            bijector=tfb.Pad(paddings=[[1, 0]], constant_values=0),
            # The name belongs to the variable, not to the `Categorical`.
            name='logits')),
    components_distribution=tfd.Independent(
        tfd.Normal(
            loc=tf.Variable(
                # Use Rademacher...cuz why not?
                tfp.math.random_rademacher([num_components, event_size]),
                name='loc'),
            scale=tfp.util.TransformedVariable(
                # Diagonal covariance: one scale per component AND per
                # coordinate. (A `[num_components, 1]` shape would make this a
                # scaled-identity mixture, contradicting the "mvndiag" name.)
                10. * tf.ones([num_components, event_size]),
                bijector=tfb.Softplus(),  # Use Softplus...cuz why not?
                # The name belongs to the variable, not to the `Normal`.
                name='scale')),
        reinterpreted_batch_ndims=1),
    name='learnable_mix_mvndiag_first_fixed')

print(learnable_mix_mvndiag_first_fixed)
print(learnable_mix_mvndiag_first_fixed.trainable_variables)
In [0]:
# Mixture of `num_components` full-covariance multivariate normals, each
# parameterized by a learnable mean and a learnable lower-triangular scale.
learnable_mix_mvntril = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        logits=tf.Variable(
            # With `1.` all logits start at zero (uniform mixture weights).
            # Changing the `1.` initializes with a geometric decay.
            -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32),
            name='logits')),
    components_distribution=tfd.MultivariateNormalTriL(
        loc=tf.Variable(tf.zeros([num_components, event_size]), name='loc'),
        scale_tril=tfp.util.TransformedVariable(
            # One `event_size x event_size` scale matrix per component,
            # initialized to `10 * I`.
            10. * tf.eye(event_size, batch_shape=[num_components]),
            # `FillScaleTriL` maps the unconstrained vector stored in the
            # variable to a lower-triangular matrix.
            bijector=tfb.FillScaleTriL(),
            # The name belongs to the variable, not to the distribution.
            name='scale_tril')),
    name='learnable_mix_mvntril')

print(learnable_mix_mvntril)
print(learnable_mix_mvntril.trainable_variables)
In [0]:
# Make a bijector which pads an eye to what otherwise fills a tril.
def num_tril_nonzero(num_rows):
  """Returns the entry count of a `num_rows x num_rows` lower triangle.

  That is `num_rows * (num_rows + 1) // 2` (diagonal included).
  """
  return (num_rows * (num_rows + 1)) // 2


def num_tril_rows(nnz):
  """Returns the row count `n` of the triangle that holds `nnz` entries.

  Solves `n * (n + 1) / 2 == nnz` for `n`, i.e. `n = sqrt(0.25 + 2 nnz) - 0.5`,
  computing in float32 and casting the result to int32.
  """
  nnz_as_float = prefer_static.cast(nnz, tf.float32)
  root = prefer_static.sqrt(0.25 + 2. * nnz_as_float)
  return prefer_static.cast(root - 0.5, tf.int32)
# TFP doesn't have a concat bijector, so we roll out our own.
class PadEye(tfb.Bijector):
  """Prepends one constant row, `tril_fn(eye)`, along the second-to-last axis.

  The forward pass maps `x` of shape `[..., n, nnz]` to `[..., n + 1, nnz]`,
  where the new leading row is `tril_fn` applied to an identity matrix. The
  inverse pass simply drops that leading row. Since the padded row is constant
  and `x` is passed through unchanged, both log-det-Jacobians are zero.

  Args:
    tril_fn: Either an object with an `inverse` attribute (e.g. a bijector),
      in which case `tril_fn.inverse` is used, or a plain callable that maps
      a `[..., num_rows, num_rows]` matrix to a length-`nnz` vector.
      Defaults to `tfb.FillScaleTriL()`.
  """

  def __init__(self, tril_fn=None):
    if tril_fn is None:
      tril_fn = tfb.FillScaleTriL()
    # Accept either a bijector (use its `inverse`) or a bare callable.
    self._tril_fn = getattr(tril_fn, 'inverse', tril_fn)
    super(PadEye, self).__init__(
        forward_min_event_ndims=2,
        inverse_min_event_ndims=2,
        is_constant_jacobian=True,
        name='PadEye')

  def _forward(self, x):
    # NOTE: requires the last dimension of `x` to be statically known.
    num_rows = int(num_tril_rows(tf.compat.dimension_value(x.shape[-1])))
    # Build the identity in `x`'s dtype; the default (float32) would make the
    # `tf.concat` below fail for any other floating-point dtype.
    eye = tf.eye(
        num_rows, batch_shape=prefer_static.shape(x)[:-2], dtype=x.dtype)
    # `[..., tf.newaxis, :]` turns the flattened identity into a single row
    # that is concatenated in front of `x` along the second-to-last axis.
    return tf.concat([self._tril_fn(eye)[..., tf.newaxis, :], x],
                     axis=prefer_static.rank(x) - 2)

  def _inverse(self, y):
    # Drop the constant row that `_forward` prepended.
    return y[..., 1:, :]

  def _forward_log_det_jacobian(self, x):
    return tf.zeros([], dtype=x.dtype)

  def _inverse_log_det_jacobian(self, y):
    return tf.zeros([], dtype=y.dtype)

  def _forward_event_shape(self, in_shape):
    # Adds 1 to the second-to-last dimension (one-hot at index `n - 2`).
    # NOTE(review): assumes `in_shape` supports elementwise `+` with an int32
    # vector -- TODO confirm for `tf.TensorShape` inputs.
    n = prefer_static.size(in_shape)
    return in_shape + prefer_static.one_hot(n - 2, depth=n, dtype=tf.int32)

  def _inverse_event_shape(self, out_shape):
    # Subtracts 1 from the second-to-last dimension (one-hot at index `n - 2`).
    n = prefer_static.size(out_shape)
    return out_shape - prefer_static.one_hot(n - 2, depth=n, dtype=tf.int32)
# Lower-triangular fill with a Softplus applied to the diagonal entries.
tril_bijector = tfb.FillScaleTriL(diag_bijector=tfb.Softplus())

# Mixture weights: `Pad` prepends one constant entry, so only the trailing
# `num_components - 1` logits are stored in the underlying variable.
fixed_first_mixture = tfd.Categorical(
    logits=tfp.util.TransformedVariable(
        # Replacing the `1.` with another value initializes the logits with a
        # geometric decay.
        -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32),
        bijector=tfb.Pad(paddings=[(1, 0)]),
        name='logits'))

# Components: the first component's `loc` and `scale_tril` are supplied by the
# padding bijectors rather than by the variables, which only hold the
# parameters of the remaining components.
fixed_first_components = tfd.MultivariateNormalTriL(
    loc=tfp.util.TransformedVariable(
        tf.zeros([num_components, event_size]),
        # Pad along the component axis (second-to-last), not the event axis.
        bijector=tfb.Pad(paddings=[(1, 0)], axis=-2),
        name='loc'),
    scale_tril=tfp.util.TransformedVariable(
        10. * tf.eye(event_size, batch_shape=[num_components]),
        # `Chain` applies right-to-left: `PadEye` prepends the flattened
        # identity row, then `tril_bijector` fills the triangular matrices.
        bijector=tfb.Chain([tril_bijector, PadEye(tril_bijector)]),
        name='scale_tril'))

learnable_mix_mvntril_fixed_first = tfd.MixtureSameFamily(
    mixture_distribution=fixed_first_mixture,
    components_distribution=fixed_first_components,
    name='learnable_mix_mvntril_fixed_first')

print(learnable_mix_mvntril_fixed_first)
print(learnable_mix_mvntril_fixed_first.trainable_variables)