In [1]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Date: 06/08/2020
This is a self-contained notebook that reproduces the toy example in Fig. 1 of the Guided ES paper (Maheswaranathan et al., "Guided evolutionary strategies: augmenting random search with surrogate gradients").
In [2]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax.experimental import optimizers
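# Note: in more recent JAX releases, this optimizers module lives at
# jax.example_libraries.optimizers.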
print(f'Jax version: {jax.__version__}')
In [3]:
# Generate problem data.
rs = np.random.RandomState(seed=0)
m = 2000 # Number of data points.
n = 1000 # Number of variables.
A = rs.randn(m, n)
b = rs.randn(m, 1)
xstar = np.linalg.lstsq(A, b, rcond=None)[0]
f_star = (0.5/float(m)) * np.linalg.norm(np.dot(A, xstar) - b) ** 2
A = jnp.array(A)
b = jnp.array(b)
def l2_normalize(vector):
return vector / jnp.linalg.norm(vector)
# A fixed, unit-norm bias direction that will be added to the observed gradient below.
grad_bias = 1.0 * l2_normalize(jnp.array(rs.randn(n, 1)))
def loss_and_grad_fun(x):
residual = jnp.dot(A, x) - b
loss = 0.5 * jnp.linalg.norm(residual) ** 2 / float(m)
# The 'gradient' that we observe is a biased version of the true gradient:
# we add the fixed bias direction above, scaled by the norm of the true gradient.
# This mimics scenarios where we only have access to a biased surrogate gradient.
err = jnp.dot(A.T, residual) / float(m)
gradient = err + grad_bias * jnp.linalg.norm(err)
return loss, gradient
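As a quick sanity check (this cell is not part of the original notebook and only uses names defined above), we can compare the biased gradient returned by loss_and_grad_fun with the exact gradient computed by jax.grad; by construction the bias has the same norm as the true gradient.
In [ ]:
# Sanity check: measure how far the observed gradient is from the exact one.
true_loss = lambda x: 0.5 * jnp.linalg.norm(jnp.dot(A, x) - b) ** 2 / float(m)
x0 = jnp.zeros((n, 1))
exact_grad = jax.grad(true_loss)(x0)
_, biased_grad = loss_and_grad_fun(x0)
rel_bias = jnp.linalg.norm(biased_grad - exact_grad) / jnp.linalg.norm(exact_grad)
print(f'Relative gradient bias: {rel_bias:.3f}')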
In [4]:
# Plain SGD applied to the biased gradient, with step size 5e-3.
init_opt, update_opt, get_params = optimizers.sgd(5e-3)
@jax.jit
def step_fun(step, opt_state):
x = get_params(opt_state)
loss, gradient = loss_and_grad_fun(x)
opt_state = update_opt(step, gradient, opt_state)
return loss, opt_state
In [5]:
%%time
x = jnp.zeros((n, 1))
opt_state = init_opt(x)
fobj = []
for step in range(10000):
loss, opt_state = step_fun(step, opt_state)
fobj.append(loss)
# Store training curve for plotting later.
f_gd = np.array(fobj).copy()
In [6]:
# Hyperparameters for Vanilla ES
sigma = 0.1
beta = 1.0
learning_rate = 0.2
# Defines the distribution for sampling parameter perturbations:
# epsilon ~ N(0, (sigma**2 / n) * I), so that E[||epsilon||^2] = sigma**2.
scale = sigma / np.sqrt(n)
def sample(prng_key):
return scale * jax.random.normal(prng_key, shape=(n, 1))
init_opt, update_opt, get_params = optimizers.sgd(learning_rate)
base_key = jax.random.PRNGKey(0)
@jax.jit
def step_fun(step, opt_state):
prng_key = jax.random.fold_in(base_key, step)
epsilon = sample(prng_key)
x = get_params(opt_state)
# We utilize antithetic (positive and negative) samples.
f_pos, _ = loss_and_grad_fun(x + epsilon)
f_neg, _ = loss_and_grad_fun(x - epsilon)
# This update is a stochastic finite difference estimate of the true gradient.
update = (beta / (2 * sigma ** 2)) * (f_pos - f_neg) * epsilon
opt_state = update_opt(step, update, opt_state)
return loss_and_grad_fun(x)[0], opt_state
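For reference, the update computed in step_fun above is the antithetic ES estimate

$$\hat{g} = \frac{\beta}{2\sigma^2}\left(f(x+\epsilon) - f(x-\epsilon)\right)\epsilon, \qquad \epsilon \sim \mathcal{N}\!\left(0, \tfrac{\sigma^2}{n} I\right),$$

a stochastic finite-difference estimate whose expectation is, to first order in $\epsilon$, proportional to the true gradient.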
In [7]:
%%time
x = jnp.zeros((n, 1))
opt_state = init_opt(x)
# Run the optimizer.
fobj = []
for step in range(10000):
loss, opt_state = step_fun(step, opt_state)
fobj.append(loss)
# Store training curve for plotting later.
f_ves = np.array(fobj).copy()
Guided ES is our proposed method. It uses a diagonal plus low-rank covariance matrix for drawing perturbations, where the low-rank subspace is spanned by the available gradient information.
This allows it to incorporate the biased gradient information, while still minimizing the true loss function.
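Concretely (matching the constants a and c defined in the cell below), the perturbations are drawn as

$$\epsilon \sim \mathcal{N}(0, \Sigma), \qquad \Sigma = \frac{\alpha}{n}\,\sigma^2 I + \frac{1-\alpha}{k}\,\sigma^2 U U^\top,$$

where $U$ is an $n \times k$ orthonormal basis for the guiding subspace (here, the QR factor of the surrogate gradient, with $k = 1$), and $\alpha$ trades off isotropic exploration against search along the guiding subspace.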
In [8]:
# Hyperparameters for Guided ES
sigma = 0.1
alpha = 0.5
beta = 1.0
k = 1 # Defines the dimensionality of the low-rank subspace.
# Defines parameters of the distribution for sampling perturbations.
a = sigma * np.sqrt(alpha / float(n))
c = sigma * np.sqrt((1. - alpha) / float(k))
# Defines the distribution for sampling parameter perturbations.
def sample(prng_key, gradient_subspace):
keys = jax.random.split(prng_key, 2)
epsilon_full = jax.random.normal(keys[0], shape=(n, 1))
epsilon_subspace = jax.random.normal(keys[1], shape=(k, 1))
Q, _ = jnp.linalg.qr(gradient_subspace)
epsilon = a * epsilon_full + c * jnp.dot(Q, epsilon_subspace)
return epsilon
init_opt, update_opt, get_params = optimizers.sgd(0.2)
base_key = jax.random.PRNGKey(0)
@jax.jit
def step_fun(step, opt_state):
prng_key = jax.random.fold_in(base_key, step)
x = get_params(opt_state)
loss, gradient = loss_and_grad_fun(x)
epsilon = sample(prng_key, gradient)
# We utilize antithetic (positive and negative) samples.
f_pos, _ = loss_and_grad_fun(x + epsilon)
f_neg, _ = loss_and_grad_fun(x - epsilon)
# This update is a stochastic finite difference estimate of the true gradient.
update = (beta / (2 * sigma ** 2)) * (f_pos - f_neg) * epsilon
opt_state = update_opt(step, update, opt_state)
return loss, opt_state
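As a small numerical check (this cell is not part of the original notebook), note that the trace of the Guided ES covariance is $\sigma^2$ for any $\alpha$, so the overall perturbation scale matches Vanilla ES; $\alpha$ only controls how that budget is split between the full space and the guiding subspace.
In [ ]:
# The total perturbation scale is independent of alpha:
# E[||epsilon||^2] = a**2 * n + c**2 * k = sigma**2.
print(f'a^2 * n + c^2 * k = {a ** 2 * n + c ** 2 * k:.6f}, sigma^2 = {sigma ** 2:.6f}')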
In [9]:
%%time
x = jnp.zeros((n, 1))
opt_state = init_opt(x)
# Run the optimizer.
fobj = []
for step in range(10000):
loss, opt_state = step_fun(step, opt_state)
fobj.append(loss)
# Store training curve for plotting later.
f_ges = np.array(fobj).copy()
In [10]:
COLORS = {'ges': '#7570b3', 'ves': '#1b9e77', 'sgdm': '#d95f02'}
plt.figure(figsize=(8, 6))
plt.plot(f_ves - f_star, color=COLORS['ves'], label='Vanilla ES')
plt.plot(f_gd - f_star, color=COLORS['sgdm'], label='Grad. Descent')
plt.plot(f_ges - f_star, color=COLORS['ges'], label='Guided ES')
plt.legend(fontsize=16, loc=0)
plt.xlabel('Iteration', fontsize=16)
plt.ylabel('Loss suboptimality (f - f*)', fontsize=16)
plt.title('Demo of Guided Evolutionary Strategies', fontsize=16);
In [ ]: