Copyright 2019 DeepMind Technologies Limited
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
This colab reproduces the plots in Figures 2 and 3 in Section 3 of the paper [1]. We consider a particular instance of the stochastic gradient problem, eqn. (10), in which we stochastically estimate the following quantity:
$\eta = \nabla_{\theta} \int \mathcal{N}(x|\mu, \sigma^2) f(x; k) dx; \quad \theta \in \{\mu, \sigma\}; \quad f \in \{(x-k)^2, \exp(-kx^2), \cos(kx)\}.$
Here the measure is a Gaussian distribution and the cost function is univariate.
In this experiment we consider several gradient estimators: the score-function estimator (on its own, with a mean baseline, and with the optimal baseline), the pathwise estimator, and the measure-valued gradient estimator with coupling.
Since all the estimators are unbiased (have the same expectation), we compare the variance of these gradient estimators. A lower-variance estimator is almost universally preferred to a higher-variance one. For this simple univariate problem, we compute the variance via numerical integration to remove any noise in the measurements.
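For reference, writing $\hat{\eta}_{\theta}$ for a one-sample estimate of $\eta$ (this notation only summarizes what the code below computes), the basic estimators take the forms

$\hat{\eta}^{\mathrm{SF}}_{\theta} = f(x; k)\, \nabla_{\theta} \log \mathcal{N}(x|\mu, \sigma^2), \quad x \sim \mathcal{N}(\mu, \sigma^2);$

$\hat{\eta}^{\mathrm{PW}}_{\mu} = f'(x; k), \quad \hat{\eta}^{\mathrm{PW}}_{\sigma} = f'(x; k)\,\frac{x - \mu}{\sigma}, \quad x \sim \mathcal{N}(\mu, \sigma^2);$

$\hat{\eta}^{\mathrm{MV}}_{\mu} = \frac{f(\mu + \sigma w; k) - f(\mu - \sigma w; k)}{\sqrt{2\pi}\,\sigma}, \quad \hat{\eta}^{\mathrm{MV}}_{\sigma} = \frac{f(\mu + \sigma m; k) - f(\mu + \sigma n; k)}{\sigma},$

where $w$ is a Weibull sample and $(m, n)$ is the coupled double-sided Maxwell/Normal pair defined in the code below. The baselines subtract a constant from $f(x; k)$ in the score-function estimator before multiplying by the score; since the score has zero mean, this changes the variance but not the expectation.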
In [0]:
import numpy as np
import scipy.integrate
import scipy.stats
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt

sns.set_context('paper', font_scale=2.0, rc={'lines.linewidth': 2.0})
sns.set_style('whitegrid')

# We use INTEGRATION_LIMIT instead of infinity in integration limits.
INTEGRATION_LIMIT = 10.
# Threshold for testing the unbiasedness of estimators.
EPS = 1e-4
# Whether to save the resulting plots on disk.
SAVE_PLOTS = True
In [0]:
class SquareCost(object):
  """The cost f(x; k) = (x - k)^2"""
  name = 'square'

  def __init__(self, k):
    self.k = k

  def value(self, x):
    return (x - self.k) ** 2

  def derivative(self, x):
    return 2 * (x - self.k)


class CosineCost(object):
  """The cost f(x; k) = cos kx"""
  name = 'cos'

  def __init__(self, k):
    self.k = k

  def value(self, x):
    return np.cos(self.k * x)

  def derivative(self, x):
    return -self.k * np.sin(self.k * x)


class ExponentialCost(object):
  """The cost f(x; k) = exp(-k x^2)"""
  name = 'exp'

  def __init__(self, k):
    self.k = k

  def value(self, x):
    return np.exp(-self.k * x ** 2)

  def derivative(self, x):
    return (-2 * self.k * x) * np.exp(-self.k * x ** 2)
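As a quick sanity check of the analytic derivatives (not part of the original figure pipeline; the test point, step size, and tolerance below are arbitrary choices), one can compare them against central finite differences:

In [0]:
# Optional check: the analytic derivative of each cost should agree with a
# central finite difference at a test point.
def check_cost_derivative(Cost, k=1.3, x=0.7, h=1e-6):
  cost = Cost(k)
  finite_diff = (cost.value(x + h) - cost.value(x - h)) / (2 * h)
  assert abs(finite_diff - cost.derivative(x)) < 1e-5

for TestCost in [SquareCost, CosineCost, ExponentialCost]:
  check_cost_derivative(TestCost)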
In [0]:
class Normal(object):
  """Univariate Normal (Gaussian) measure."""

  def __init__(self, mean, std, verify_unbiasedness):
    self.distrib = scipy.stats.norm(loc=mean, scale=std)
    self.mean = mean
    self.std = std
    self.verify_unbiasedness = verify_unbiasedness

  def expect(self, g):
    """Computes the mean: E_p(x) g(x)."""
    return scipy.integrate.quad(lambda x: self.distrib.pdf(x) * g(x),
                                -INTEGRATION_LIMIT, INTEGRATION_LIMIT)

  def var(self, g, expect_g):
    """Computes the variance given the mean: E_p(x) (g(x) - E g(x))^2."""
    if self.verify_unbiasedness:
      assert (self.expect(g)[0] - expect_g) ** 2 < EPS
    return self.expect(lambda x: (g(x) - expect_g) ** 2)

  def cov(self, g, expect_g, h, expect_h):
    """Computes the covariance of two functions given their means:
    E_p(x) (g(x) - E g(x)) (h(x) - E h(x))
    """
    if self.verify_unbiasedness:
      assert (self.expect(g)[0] - expect_g) ** 2 < EPS
      assert (self.expect(h)[0] - expect_h) ** 2 < EPS
    return self.expect(lambda x: (g(x) - expect_g) * (h(x) - expect_h))

  def dlogpdf_dmean(self, x):
    r"""Computes the score function for the mean: \nabla_mean \log p(x; mean, std).
    The score function is part of the score-function estimator, see eqn. (13).
    """
    return (x - self.mean) / self.std ** 2

  def dlogpdf_dstd(self, x):
    r"""Computes the score function for the std: \nabla_std \log p(x; mean, std).
    The score function is part of the score-function estimator, see eqn. (13).
    """
    return -(((self.mean + self.std - x) *
              (-self.mean + self.std + x)) / self.std ** 3)

  def dx_dmean(self, x):
    r"""Computes \nabla_mean x.
    This is part of the pathwise estimator, see eqn. (35b).
    For the derivation, see eqn. (37).
    """
    return 1.

  def dx_dstd(self, x):
    r"""Computes \nabla_std x.
    This is part of the pathwise estimator, see eqn. (35b).
    For the derivation, see eqn. (37).
    """
    return (x - self.mean) / self.std
class StandardWeibull(object):
  """Weibull(2, 0.5) is a distribution used for the measure-valued derivative
  w.r.t. the Normal mean.
  See equation (46) for the derivation. This distribution has the density
  function x * exp(-x^2 / 2) for x > 0.
  """

  def __init__(self, verify_unbiasedness):
    self.verify_unbiasedness = verify_unbiasedness

  def expect(self, g):
    """Computes the mean: E_Weibull(x) g(x)."""
    weibull_pdf = lambda x: x * np.exp(-0.5 * x ** 2)
    return scipy.integrate.quad(lambda x: weibull_pdf(x) * g(x),
                                0, INTEGRATION_LIMIT)

  def var(self, g, expect_g):
    """Computes the variance given the mean: E_Weibull(x) (g(x) - E g(x))^2."""
    if self.verify_unbiasedness:
      assert (self.expect(g)[0] - expect_g) ** 2 < EPS
    return self.expect(lambda x: (g(x) - expect_g) ** 2)
class StandardDsMaxwellCoupledWithNormal(object):
  r"""Standard double-sided Maxwell distribution coupled with the
  standard Normal distribution. This is a bivariate distribution which is used
  for the measure-valued derivative w.r.t. the Normal standard deviation,
  see Table 1.
  The standard double-sided Maxwell distribution has the density function
  x^2 exp(-x^2 / 2) / sqrt(2 pi) for x \in R.
  To reduce the variance of the estimator, we couple the positive
  (double-sided Maxwell) and negative (Gaussian) parts of the estimator.
  See Section 7.2 for a discussion of this idea. Technically, this is achieved
  by representing a standard Normal sample as m * u,
  where m ~ DSMaxwell and u ~ U[0, 1].
  """

  def __init__(self, verify_unbiasedness):
    self.verify_unbiasedness = verify_unbiasedness

  def expect(self, g):
    """Computes the mean E_p(m, n) g(m, n), where m has a marginal DS-Maxwell
    distribution and n has a marginal Normal distribution."""
    def ds_maxwell_pdf(x):
      return x ** 2 * np.exp(-0.5 * x ** 2) / np.sqrt(2 * np.pi)
    return scipy.integrate.dblquad(
        # m: double-sided Maxwell, u: U[0, 1].
        # The PDF of U[0, 1] is the constant 1.
        lambda m, u: ds_maxwell_pdf(m) * g(m, m * u),
        # Limits for the Uniform.
        0, 1,
        # Limits for the double-sided Maxwell. Infinity is not supported by
        # dblquad.
        lambda x: -INTEGRATION_LIMIT, lambda x: INTEGRATION_LIMIT,
    )

  def var(self, g, expect_g):
    """Computes the variance E_p(m, n) (g(m, n) - E g(m, n))^2, where m has
    a marginal DS-Maxwell distribution and n has a marginal Normal
    distribution."""
    if self.verify_unbiasedness:
      assert (self.expect(g)[0] - expect_g) ** 2 < EPS
    return self.expect(lambda m, n: (g(m, n) - expect_g) ** 2)
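To relate these numerically integrated quantities to how the estimators are used in practice, the following is a minimal Monte Carlo sketch (not needed for the figures; the cost, measure parameters, sample size, and seed are arbitrary) of the score-function gradient with respect to the mean. Its empirical mean and variance should roughly match the corresponding integrated values computed by numerical_integration below.

In [0]:
# Monte Carlo sketch of the score-function estimator for the gradient w.r.t.
# the mean.
def mc_score_function_dmean(cost, measure, num_samples=100000, seed=0):
  rng = np.random.RandomState(seed)
  x = rng.normal(measure.mean, measure.std, size=num_samples)
  # One score-function sample per draw: f(x; k) * d/dmean log p(x).
  samples = cost.value(x) * measure.dlogpdf_dmean(x)
  return np.mean(samples), np.var(samples)

mc_grad, mc_var = mc_score_function_dmean(SquareCost(1.), Normal(1., 1.5, False))
print('MC estimate of the gradient w.r.t. the mean: {}, variance: {}'.format(
    mc_grad, mc_var))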
In [0]:
def numerical_integration(Cost, k, mean, std, verify_unbiasedness=False):
  """This function numerically evaluates the variance of gradient estimators.

  Arguments:
    Cost: the class of a cost function
    k: a list/NumPy vector of values for the cost parameter k
    mean: a scalar parameter of the Normal measure
    std: a scalar parameter of the Normal measure
    verify_unbiasedness: if True, perform additional asserts that verify
      that the estimators are unbiased

  Returns:
    A dictionary {key: NumPy array}. The keys have the form var_..., where ...
    is the name of the estimator. The dimensions of the NumPy arrays are
    [len(k), 2, 2], where the second dimension is [dmean, dstd], and the last
    dimension is [value, integration_error].
  """
  measure = Normal(mean, std, verify_unbiasedness)
  weibull = StandardWeibull(verify_unbiasedness)
  ds_maxwell_coupled_with_normal = StandardDsMaxwellCoupledWithNormal(
      verify_unbiasedness)

  ret = {}
  for key in ['var_sf',
              'var_sf_mean_baseline',
              'var_sf_optimal_baseline',
              'var_pathwise',
              'var_measure_valued_coupled']:
    ret[key] = np.zeros([len(k), 2, 2])

  for i in range(len(k)):
    cost = Cost(k[i])
    expect_loss = measure.expect(cost.value)[0]
    # Compute $\nabla_{\theta} \int \mathcal{N}(x|\mu, \sigma^2) f(x; k) dx$
    # using the score-function estimator.
    d_expect_loss = [
        measure.expect(lambda x: cost.value(x) * measure.dlogpdf_dmean(x))[0],
        measure.expect(lambda x: cost.value(x) * measure.dlogpdf_dstd(x))[0]
    ]
    # Variance of the score-function estimator: Section 4, eqn. (13).
    ret['var_sf'][i] = [
        measure.var(lambda x: cost.value(x) * measure.dlogpdf_dmean(x),
                    d_expect_loss[0]),
        measure.var(lambda x: cost.value(x) * measure.dlogpdf_dstd(x),
                    d_expect_loss[1])
    ]
    # Variance of the score-function estimator with the mean baseline:
    # Section 4, eqn. (14).
    ret['var_sf_mean_baseline'][i] = [
        measure.var(
            lambda x: (cost.value(x) - expect_loss) * measure.dlogpdf_dmean(x),
            d_expect_loss[0]),
        measure.var(
            lambda x: (cost.value(x) - expect_loss) * measure.dlogpdf_dstd(x),
            d_expect_loss[1])
    ]
    # Computes the optimal baseline for the score-function estimator
    # using Section 7.4.1, eqn. (65).
    # Note that it has different values for mean and std.
    optimal_baseline = [
        (measure.cov(measure.dlogpdf_dmean, 0.,
                     lambda x: cost.value(x) * measure.dlogpdf_dmean(x),
                     d_expect_loss[0])[0]
         / measure.var(measure.dlogpdf_dmean, 0.)[0]),
        (measure.cov(measure.dlogpdf_dstd, 0.,
                     lambda x: cost.value(x) * measure.dlogpdf_dstd(x),
                     d_expect_loss[1])[0]
         / measure.var(measure.dlogpdf_dstd, 0.)[0])
    ]
    # Variance of the score-function estimator with the optimal baseline:
    # Section 4, eqn. (14).
    ret['var_sf_optimal_baseline'][i] = [
        measure.var(
            lambda x: (cost.value(x) - optimal_baseline[0]) * measure.dlogpdf_dmean(x),
            d_expect_loss[0]),
        measure.var(
            lambda x: (cost.value(x) - optimal_baseline[1]) * measure.dlogpdf_dstd(x),
            d_expect_loss[1])
    ]
    # Variance of the pathwise estimator. Here we use the "implicit" form of
    # the estimator that allows reusing the same Gaussian measure.
    # See Section 5, eqn. (35) for details.
    ret['var_pathwise'][i] = [
        measure.var(lambda x: cost.derivative(x) * measure.dx_dmean(x),
                    d_expect_loss[0]),
        measure.var(lambda x: cost.derivative(x) * measure.dx_dstd(x),
                    d_expect_loss[1])
    ]
    # Variance of the measure-valued gradient estimator (Section 6, eqn. (44),
    # Table 1) with variance reduction via coupling (Section 7.2).
    ret['var_measure_valued_coupled'][i] = [
        # We couple the Weibulls from the positive and negative parts of the
        # estimator simply by reusing the value of the Weibull.
        weibull.var(
            lambda x: (cost.value(mean + std * x) - cost.value(mean - std * x))
            / (np.sqrt(2 * np.pi) * std),
            d_expect_loss[0]),
        # See Section 7.2 and the documentation of
        # StandardDsMaxwellCoupledWithNormal for details on this coupling.
        # Here m ~ DS-Maxwell, n ~ Normal(0, 1).
        ds_maxwell_coupled_with_normal.var(
            lambda m, n: (cost.value(m * std + mean) - cost.value(n * std + mean))
            / std,
            d_expect_loss[1])
    ]
  return ret
In [0]:
def plot(k, ret, param_idx, logx, logy, ylabel, ylim, filename, xticks=None):
  plt.figure(figsize=[8, 5])
  plt.plot(k, ret['var_sf'][:, param_idx, 0],
           label='Score function')
  # plt.plot(k, ret['var_sf_mean_baseline'][:, param_idx, 0],
  #          label='Score function + mean baseline')
  plt.plot(k, ret['var_sf_optimal_baseline'][:, param_idx, 0],
           label='Score function + variance reduction')
  plt.plot(k, ret['var_pathwise'][:, param_idx, 0],
           label='Pathwise')
  plt.plot(k, ret['var_measure_valued_coupled'][:, param_idx, 0],
           label='Measure-valued + variance reduction')
  plt.xlabel(r'$k$')
  plt.ylabel(ylabel)
  plt.xlim([np.min(k), np.max(k)])
  plt.ylim(ylim)
  if logx:
    plt.xscale('log')
  if logy:
    plt.yscale('log')
  if xticks is not None:
    plt.xticks(xticks)
    x_axis = plt.gca().get_xaxis()
    x_axis.set_ticklabels(xticks)
    x_axis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
    x_axis.set_minor_formatter(matplotlib.ticker.NullFormatter())
  if SAVE_PLOTS:
    plt.savefig(filename, dpi=200, transparent=True)
  return plt.gca()


def plot_cost_cartoon(Cost, k, x, xticks, yticks, ylim, filename):
  f, axes = plt.subplots(1, 3, sharey='row', figsize=[12, 2])
  for i in range(len(k)):
    axes[i].plot(x, Cost(k[i]).value(x),
                 color='k', label='Value of the cost')
    axes[i].plot(x, Cost(k[i]).derivative(x),
                 color='k', linestyle='--', label='Derivative of the cost')
    axes[i].axis('on')
    axes[i].grid(False)
    axes[i].xaxis.set_tick_params(length=0)
    axes[i].xaxis.set_ticks(xticks)
    axes[i].yaxis.set_tick_params(length=0)
    axes[i].yaxis.set_ticks(yticks)
    axes[i].set_frame_on(False)
  axes[0].set_ylim(ylim)
  f.tight_layout()
  if SAVE_PLOTS:
    f.savefig(filename, dpi=200, transparent=True)
  return axes
In [0]:
for Cost in [SquareCost, CosineCost, ExponentialCost]:
  print(Cost.name)
  ret = numerical_integration(
      Cost, k=[0.1, 1., 10.], mean=1, std=1.5, verify_unbiasedness=True)
  print('Maximum integration error: {}'.format(
      max(np.max(v[..., 1]) for v in ret.values())))
In [0]:
Cost = SquareCost
k = np.linspace(-3., 3., 100)
ret = numerical_integration(Cost, k, mean=1, std=1)
print('Maximum integration error: {}'.format(
    max(np.max(v[..., 1]) for v in ret.values())))
In [0]:
plot(
    k, ret, param_idx=0,
    logx=False, logy=True, ylabel=r'Variance of the estimator for $\mu$',
    ylim=[1., 1e3], filename='variance_mu_{}.pdf'.format(Cost.name))
plot_ax = plot(
    k, ret, param_idx=1,
    logx=False, logy=True, ylabel=r'Variance of the estimator for $\sigma$',
    ylim=[1., 1e3], filename='variance_sigma_{}.pdf'.format(Cost.name))
cartoon_ax = plot_cost_cartoon(
    Cost, k=[np.min(k), 0, np.max(k)], x=np.linspace(-5., 5., 100),
    xticks=[-5, 0, 5], yticks=[-2, 0, 5], ylim=[-2, 5],
    filename='costs_{}.pdf'.format(Cost.name))
In [0]:
Cost = ExponentialCost
k = np.logspace(np.log10(0.1), np.log10(10.), 100)
ret = numerical_integration(Cost, k, mean=1, std=1)
print('Maximum integration error: {}'.format(
    max(np.max(v[..., 1]) for v in ret.values())))
In [0]:
plot(
    k, ret, param_idx=0,
    logx=True, logy=True, ylabel=r'Variance of the estimator for $\mu$',
    ylim=[1e-3, 1], xticks=[0.1, 1, 10],
    filename='variance_mu_{}.pdf'.format(Cost.name))
plot_ax = plot(
    k, ret, param_idx=1,
    logx=True, logy=True, ylabel=r'Variance of the estimator for $\sigma$',
    ylim=[1e-3, 1], xticks=[0.1, 1, 10],
    filename='variance_sigma_{}.pdf'.format(Cost.name))
cartoon_ax = plot_cost_cartoon(
    Cost, k=[np.min(k), 1, np.max(k)], x=np.linspace(-3., 3., 100),
    xticks=[-3, 0, 3], yticks=[-1, 0, 1], ylim=[-1.1, 1.1],
    filename='costs_{}.pdf'.format(Cost.name))
In [0]:
Cost = CosineCost
k = np.logspace(np.log10(0.5), np.log10(5.), 100)
ret = numerical_integration(Cost, k, mean=1, std=1)
print('Maximum integration error: {}'.format(
    max(np.max(v[..., 1]) for v in ret.values())))
In [0]:
plot(
    k, ret, param_idx=0,
    logx=True, logy=True, ylabel=r'Variance of the estimator for $\mu$',
    ylim=[0.005, 10], xticks=[0.5, 1, 2, 5],
    filename='variance_mu_{}.pdf'.format(Cost.name))
plot_ax = plot(
    k, ret, param_idx=1,
    logx=True, logy=True, ylabel=r'Variance of the estimator for $\sigma$',
    ylim=[0.1, 10], xticks=[0.5, 1, 2, 5],
    filename='variance_sigma_{}.pdf'.format(Cost.name))
cartoon_ax = plot_cost_cartoon(
    Cost,
    # Use the midpoint of the k range on the log scale for the middle panel.
    k=[np.min(k),
       10 ** ((np.log10(np.min(k)) + np.log10(np.max(k))) / 2),
       np.max(k)],
    x=np.linspace(-3., 3., 100),
    xticks=[-3, 0, 3], yticks=[-3, 0, 3], ylim=[-3, 3],
    filename='costs_{}.pdf'.format(Cost.name))
In [0]:
plt.figure(figsize=[22, 1])
plt.axis('off')
plt.grid(False)
plt.legend(*plot_ax.get_legend_handles_labels(), loc='center',
           frameon=False, ncol=5)
if SAVE_PLOTS:
  filename = 'estimators_legend.pdf'
  plt.savefig(filename, dpi=200, transparent=True)

plt.figure(figsize=[22, 1])
plt.axis('off')
plt.grid(False)
plt.legend(*cartoon_ax[0].get_legend_handles_labels(), loc='center',
           frameon=False, ncol=5)
if SAVE_PLOTS:
  filename = 'costs_legend.pdf'
  plt.savefig(filename, dpi=200, transparent=True)