Optimization

In this notebook, we explore various algorithms for solving x* = argmin_{x in R^D} f(x), where f(x) is a differentiable cost function.



In [0]:
import sklearn
import scipy
import scipy.optimize
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import itertools
import time
from functools import partial

import os


import numpy as np
from scipy.special import logsumexp
np.set_printoptions(precision=3)


# We make some wrappers around random number generation
# so it works even if we switch from numpy to JAX
import numpy as onp # original numpy

def set_seed(seed):
    onp.random.seed(seed)
    
def randn(*args):
    return onp.random.randn(*args)
        
def randperm(args):
    return onp.random.permutation(args)

In [0]:
import torch
import torchvision
print("torch version {}".format(torch.__version__))
if torch.cuda.is_available():
  print(torch.cuda.get_device_name(0))
  print("current device {}".format(torch.cuda.current_device()))
else:
  print("Torch cannot find GPU")

def set_seed(seed):
  onp.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
#torch.backends.cudnn.benchmark = True


torch version 1.1.0
Tesla T4
current device 0

In [0]:
# Tensorflow 2.0
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf
from tensorflow import keras
print("tf version {}".format(tf.__version__))
if tf.test.is_gpu_available():
    print(tf.test.gpu_device_name())
else:
    print("TF cannot find GPU")


TensorFlow 2.x selected.
tf version 2.0.0-beta1
/device:GPU:0

In [0]:
# JAX (https://github.com/google/jax)
!pip install --upgrade -q https://storage.googleapis.com/jax-releases/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-$(pip search jaxlib | grep -oP '[0-9\.]+' | head -n 1)-cp36-none-linux_x86_64.whl
!pip install --upgrade -q jax

import jax
import jax.numpy as np
import numpy as onp
from jax.scipy.special import logsumexp
from jax import grad, hessian, jacfwd, jacrev, jit, vmap
from jax.experimental import optimizers
print("jax version {}".format(jax.__version__))


jax version 0.1.43

Automatic differentiation

In this section we illustrate various AD libraries by using them to derive the gradient of the negative log likelihood for binary logistic regression applied to the Iris dataset. We compare to the manual numpy implementation.

As a minor detail, we evaluate the gradient of the NLL of the test data with the parameters set to their training MLE, in order to get an interesting signal; using a random weight vector makes the dynamic range of the output harder to see.


In [0]:
# Fit the model to a dataset, so we have an "interesting" parameter vector to use.

import sklearn.datasets
from sklearn.model_selection import train_test_split

iris = sklearn.datasets.load_iris()
X = iris["data"]
y = (iris["target"] == 2).astype(onp.int)  # 1 if Iris-Virginica, else 0
N, D = X.shape # 150, 4


X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=42)

from sklearn.linear_model import LogisticRegression

# We set C to a large number to turn off regularization.
# We don't fit the bias term to simplify the comparison below.
log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
log_reg.fit(X_train, y_train)
w_mle_sklearn = np.ravel(log_reg.coef_)
w = w_mle_sklearn

In [0]:
## Compute gradient of loss "by hand" using numpy


def BCE_with_logits(logits, targets):
    N = logits.shape[0]
    logits = logits.reshape(N,1)
    logits_plus = np.hstack([np.zeros((N,1)), logits]) # e^0=1
    logits_minus = np.hstack([np.zeros((N,1)), -logits])
    logp1 = -logsumexp(logits_minus, axis=1)
    logp0 = -logsumexp(logits_plus, axis=1)
    logprobs = logp1 * targets + logp0 * (1-targets)
    return -np.sum(logprobs)/N

# Compute using numpy
def sigmoid(x): return 0.5 * (np.tanh(x / 2.) + 1)

def predict_logit(weights, inputs):
    return np.dot(inputs, weights) # Already vectorized

def predict_prob(weights, inputs):
    return sigmoid(predict_logit(weights, inputs))

def NLL(weights, batch):
    X, y = batch
    logits = predict_logit(weights, X)
    return BCE_with_logits(logits, y)

def NLL_grad(weights, batch):
    X, y = batch
    N = X.shape[0]
    mu = predict_prob(weights, X)
    g = np.sum(np.dot(np.diag(mu - y), X), axis=0)/N
    return g

y_pred = predict_prob(w, X_test)
loss = NLL(w, (X_test, y_test))
grad_np = NLL_grad(w, (X_test, y_test))
print("params {}".format(w))
#print("pred {}".format(y_pred))
print("loss {}".format(loss))
print("grad {}".format(grad_np))


params [-4.414 -9.111  6.539 12.686]
loss 0.11824002861976624
grad [-0.235 -0.122 -0.198 -0.064]

AD in JAX

Below we use JAX to compute the gradient of the NLL for binary logistic regression. For some examples of using JAX to compute the gradients, Jacobians and Hessians of simple linear and quadratic functions, see this notebook. More details on JAX's autodiff can be found in the official autodiff cookbook.
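
As a quick self-contained check (added here for illustration, and not taken from that notebook), we can verify jax.grad and jax.hessian against the analytic gradient and Hessian of a quadratic; recall that np refers to jax.numpy in this notebook.


In [0]:
# Sanity check of JAX autodiff on a quadratic f(x) = 0.5 x^T A x + b^T x,
# whose gradient is A x + b and whose Hessian is A (for symmetric A).
A = np.array([[2.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, -1.0])
f = lambda x: 0.5 * np.dot(x, np.dot(A, x)) + np.dot(b, x)
x0 = np.array([1.0, 2.0])
assert np.allclose(grad(f)(x0), np.dot(A, x0) + b)
assert np.allclose(hessian(f)(x0), A)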


In [0]:
grad_jax = grad(NLL)(w, (X_test, y_test))
print("grad {}".format(grad_jax))
assert np.allclose(grad_np, grad_jax)


grad [-0.235 -0.122 -0.198 -0.064]

AD in Tensorflow

We just wrap the relevant forward computations inside a tf.GradientTape() context, and then call tape.gradient(objective, [variables]).
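
As a minimal warm-up (added for illustration; the logistic regression example follows in the next cell), taping y = x^2 recovers dy/dx = 2x:


In [0]:
# Toy GradientTape example: d/dx (x^2) = 2x.
x = tf.Variable(3.0)
with tf.GradientTape() as tape:
    y = x * x
print(tape.gradient(y, x))  # 2 * 3.0 = 6.0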


In [0]:
w_tf = tf.Variable(np.reshape(w, (D,1)))  
x_test_tf = tf.convert_to_tensor(X_test, dtype=np.float64) 
y_test_tf = tf.convert_to_tensor(np.reshape(y_test, (-1,1)), dtype=np.float64)
with tf.GradientTape() as tape:
    logits = tf.linalg.matmul(x_test_tf, w_tf)
    y_pred = tf.math.sigmoid(logits)
    loss_batch = tf.nn.sigmoid_cross_entropy_with_logits(y_test_tf, logits)
    loss_tf = tf.reduce_mean(loss_batch, axis=0)
grad_tf = tape.gradient(loss_tf, [w_tf])
grad_tf = grad_tf[0][:,0].numpy()
assert np.allclose(grad_np, grad_tf)

print("params {}".format(w_tf))
#print("pred {}".format(y_pred))
print("loss {}".format(loss_tf))
print("grad {}".format(grad_tf))


params <tf.Variable 'Variable:0' shape=(4, 1) dtype=float64, numpy=
array([[-4.414],
       [-9.111],
       [ 6.539],
       [12.686]])>
loss [0.118]
grad [-0.235 -0.122 -0.198 -0.064]

AD in PyTorch

We just compute the objective, call backward() on it, and then look up variable.grad. However, we have to set requires_grad=True on the parameter tensor before computing the objective, so that Torch knows to record the operations applied to it on its tape.
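
As a minimal warm-up (added for illustration; the logistic regression example follows in the next cell), the same toy computation in Torch:


In [0]:
# Toy autograd example: d/dx (x^2) = 2x.
x = torch.tensor(3.0, requires_grad=True)
y = x * x
y.backward()
print(x.grad)  # 2 * 3.0 = 6.0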


In [0]:
w_torch = torch.Tensor(np.reshape(w, [D, 1])).to(device)
w_torch.requires_grad_() 
x_test_tensor = torch.Tensor(X_test).to(device)
y_test_tensor = torch.Tensor(y_test).to(device)
y_pred = torch.sigmoid(torch.matmul(x_test_tensor, w_torch))[:,0]
criterion = torch.nn.BCELoss(reduction='mean')
loss_torch = criterion(y_pred, y_test_tensor)
loss_torch.backward()
grad_torch = w_torch.grad[:,0].cpu().numpy()
assert np.allclose(grad_np, grad_torch)

print("params {}".format(w_torch))
#print("pred {}".format(y_pred))
print("loss {}".format(loss_torch))
print("grad {}".format(grad_torch))


params tensor([[-4.4138],
        [-9.1106],
        [ 6.5387],
        [12.6857]], device='cuda:0', requires_grad=True)
loss 0.11824004352092743
grad [-0.235 -0.122 -0.198 -0.064]

Second-order, full-batch optimization

The "gold standard" of optimization is second-order methods, that leverage Hessian information. Since the Hessian has O(D^2) parameters, such methods do not scale to high-dimensional problems. However, we can sometimes approximate the Hessian using low-rank or diagonal approximations. Below we illustrate the low-rank BFGS method, and the limited-memory version of BFGS, that uses O(D H) space and O(D^2) time per step, where H is the history length.

In general, second-order methods also require exact (rather than noisy) gradients. In the context of ML, this means they are "full batch" methods, since computing the exact gradient requires evaluating the loss on all the datapoints. However, for small data problems, this is feasible (and advisable).
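
To make these ideas concrete, here is a minimal sketch (added for illustration; it is not one of the library methods used below) of a single full-batch Newton step on the logistic regression NLL defined earlier, using jax.grad, jax.hessian and jax.numpy.linalg.solve. For this tiny problem the D x D Hessian (D = 4) is cheap to form and solve.


In [0]:
# One full-batch Newton step: w_new = w - H^{-1} g, where g and H are the
# gradient and Hessian of the NLL on the full training set.
def newton_step(w, batch):
    g = grad(NLL)(w, batch)           # gradient, shape (D,)
    H = hessian(NLL)(w, batch)        # Hessian, shape (D, D)
    return w - np.linalg.solve(H, g)

# Example usage (assumes NLL, X_train, y_train, D from the cells above):
# w_new = newton_step(randn(D), (X_train, y_train))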

Below we illustrate how to use LBFGS as implemented in various libraries. Other second-order optimizers have a similar API. We use the same binary logistic regression problem as above.


In [0]:
# Repeat relevant code from AD section above, for convenience.

# Dataset
import sklearn.datasets
from sklearn.model_selection import train_test_split
iris = sklearn.datasets.load_iris()
X = iris["data"]
y = (iris["target"] == 2).astype(onp.int)  # 1 if Iris-Virginica, else 0
N, D = X.shape # 150, 4
X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=42)


# Sklearn estimate
from sklearn.linear_model import LogisticRegression
log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
log_reg.fit(X_train, y_train)
w_mle_sklearn = np.ravel(log_reg.coef_)
w = w_mle_sklearn

# Define Model and binary cross entropy loss 
def BCE_with_logits(logits, targets):
    N = logits.shape[0]
    logits = logits.reshape(N,1)
    logits_plus = np.hstack([np.zeros((N,1)), logits]) # e^0=1
    logits_minus = np.hstack([np.zeros((N,1)), -logits])
    logp1 = -logsumexp(logits_minus, axis=1)
    logp0 = -logsumexp(logits_plus, axis=1)
    logprobs = logp1 * targets + logp0 * (1-targets)
    return -np.sum(logprobs)/N

def sigmoid(x): return 0.5 * (np.tanh(x / 2.) + 1)

def predict_logit(weights, inputs):
    return np.dot(inputs, weights) # Already vectorized

def predict_prob(weights, inputs):
    return sigmoid(predict_logit(weights, inputs))

def NLL(weights, batch):
    X, y = batch
    logits = predict_logit(weights, X)
    return BCE_with_logits(logits, y)

Scipy version

We show how to use the BFGS and L-BFGS implementations from scipy.optimize.minimize.


In [0]:
import scipy.optimize

# We manually compute gradients, but could use JAX instead
def NLL_grad(weights, batch):
    X, y = batch
    N = X.shape[0]
    mu = predict_prob(weights, X)
    g = np.sum(np.dot(np.diag(mu - y), X), axis=0)/N
    return g

def training_loss(w):
    return NLL(w, (X_train, y_train))

def training_grad(w):
    return NLL_grad(w, (X_train, y_train))

set_seed(0)
w_init = randn(D)

options={'disp': None,   'maxfun': 1000, 'maxiter': 1000}
method = 'BFGS'
w_mle_scipy = scipy.optimize.minimize(
    training_loss, w_init, jac=training_grad,
    method=method, options=options).x   

print("parameters from sklearn {}".format(w_mle_sklearn))
print("parameters from scipy-bfgs {}".format(w_mle_scipy))


parameters from sklearn [-4.414 -9.111  6.539 12.686]
parameters from scipy-bfgs [-4.417 -9.117  6.543 12.695]

In [0]:
# The limited-memory version requires that we work with 64-bit floats, since it is implemented in Fortran.

def training_loss2(w):
    l = NLL(w, (X_train, y_train))
    return onp.float64(l)

def training_grad2(w):
    g = NLL_grad(w, (X_train, y_train))
    return onp.asarray(g, dtype=onp.float64)
                 
set_seed(0)
w_init = randn(D)
memory = 10
options={'disp': None, 'maxcor': memory,  'maxfun': 1000, 'maxiter': 1000}
# The code also handles bound constraints, hence the name
method = 'L-BFGS-B'
w_mle_scipy = scipy.optimize.minimize(training_loss2, w_init, jac=training_grad2, method=method, options=options).x


print("parameters from sklearn {}".format(w_mle_sklearn))
print("parameters from scipy-lbfgs {}".format(w_mle_scipy))


parameters from sklearn [-4.414 -9.111  6.539 12.686]
parameters from scipy-lbfgs [-4.415 -9.114  6.54  12.691]

PyTorch version

We show how to use the implementation in torch.optim.LBFGS.


In [0]:
# Put data into PyTorch format.
import torch
from torch.utils.data import DataLoader, TensorDataset

N, D = X_train.shape
x_train_tensor = torch.Tensor(X_train)
y_train_tensor = torch.Tensor(y_train)
data_set = TensorDataset(x_train_tensor, y_train_tensor)

In [0]:
# Define model and loss.

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(D, 1, bias=False) 
        
    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred
    
set_seed(0)
model = Model() 
criterion = torch.nn.BCELoss(reduction='mean')

In [0]:
optimizer = torch.optim.LBFGS(model.parameters(), history_size=10)
    
def closure():
    optimizer.zero_grad()
    y_pred = model(x_train_tensor)
    loss = criterion(y_pred, y_train_tensor)
    #print('loss:', loss.item())
    loss.backward()
    return loss

max_iter = 10
for i in range(max_iter):
    loss = optimizer.step(closure)

params = list(model.parameters())
w_torch_bfgs = params[0][0].detach().numpy() #(D,) vector
print("parameters from sklearn {}".format(w_mle_sklearn))
print("parameters from torch-bfgs {}".format(w_torch_bfgs))


parameters from sklearn [-4.414 -9.111  6.539 12.686]
parameters from torch-bfgs [-4.415 -9.114  6.54  12.691]

TF version

There is also an LBFGS implementation available for TensorFlow, via TensorFlow Probability (tfp.optimizer.lbfgs_minimize); a minimal sketch is shown below.
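
The following sketch is added for illustration and is not executed in this notebook; it assumes tensorflow_probability is installed and reuses X_train, y_train, and D from above.


In [0]:
# Sketch of full-batch LBFGS via TensorFlow Probability (not run here).
import tensorflow_probability as tfp

x_tf = tf.constant(X_train, dtype=tf.float64)
y_tf = tf.constant(y_train.reshape(-1, 1), dtype=tf.float64)

def nll(w):
    logits = tf.linalg.matmul(x_tf, tf.reshape(w, (-1, 1)))
    return tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=y_tf, logits=logits))

def nll_and_grad(w):
    # lbfgs_minimize expects a function returning (loss, gradient).
    return tfp.math.value_and_gradient(nll, w)

results = tfp.optimizer.lbfgs_minimize(
    nll_and_grad, initial_position=tf.zeros(D, dtype=tf.float64),
    max_iterations=100)
print(results.position)  # should be close to w_mle_sklearn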

Stochastic gradient descent

In this section we illustrate how to implement SGD. We apply it to a simple convex problem, namely MLE for binary logistic regression on the small Iris dataset, so that we can compare against the exact full-batch methods illustrated above.

Numpy version

We show a minimal implementation of SGD using vanilla numpy. For convenience, we use tf.data and TFDS to create a stream of mini-batches. We compute gradients by hand here, but could use any AD library instead.


In [0]:
import tensorflow_datasets as tfds


def make_batcher(batch_size, X, y):
  def get_batches():
    # Convert numpy arrays to tfds
    ds = tf.data.Dataset.from_tensor_slices({"X": X, "y": y})
    ds = ds.batch(batch_size)
    # convert tfds into an iterable of dict of NumPy arrays
    return tfds.as_numpy(ds)
  return get_batches

batcher = make_batcher(2, X_train, y_train)

for epoch in range(2):
  print('epoch {}'.format(epoch))
  for batch in batcher():
    x, y = batch["X"], batch["y"]
    #print(x.shape)


epoch 0
epoch 1

In [0]:
def sgd(params, loss_fn, grad_loss_fn, get_batches_as_dict, max_epochs, lr):
    print_every = max(1, int(0.1*max_epochs))
    for epoch in range(max_epochs):
        epoch_loss = 0.0
        for batch_dict in get_batches_as_dict():
            x, y = batch_dict["X"], batch_dict["y"]
            batch = (x, y)
            batch_grad = grad_loss_fn(params, batch)
            params = params - lr*batch_grad
            batch_loss = loss_fn(params, batch) # Average loss within this batch
            epoch_loss += batch_loss
        if epoch % print_every == 0:
            print('Epoch {}, Loss {}'.format(epoch, epoch_loss))
    return params

In [0]:
set_seed(0)
D = X_train.shape[1]
w_init = onp.random.randn(D)

def training_loss2(w):
    l = NLL(w, (X_train, y_train))
    return onp.float64(l)

def training_grad2(w):
    g = NLL_grad(w, (X_train, y_train))
    return onp.asarray(g, dtype=onp.float64)

max_epochs = 5
lr = 0.1
batch_size = 10
batcher = make_batcher(batch_size, X_train, y_train)
w_mle_sgd = sgd(w_init, NLL, NLL_grad, batcher, max_epochs, lr)
print(w_mle_sgd)


Epoch 0, Loss 21.775604248046875
Epoch 1, Loss 3.2622179985046387
Epoch 2, Loss 3.1074540615081787
Epoch 3, Loss 2.9816956520080566
Epoch 4, Loss 2.875518798828125
[-0.399 -0.919  0.311  2.174]

Jax version

JAX has a small optimization library focused on stochastic first-order optimizers. Every optimizer is modeled as an (init_fun, update_fun, get_params) triple of functions. The init_fun is used to initialize the optimizer state, which could include things like momentum variables, and the update_fun accepts a gradient and an optimizer state to produce a new optimizer state. The get_params function extracts the current iterate (i.e. the current parameters) from the optimizer state. The parameters being optimized can be ndarrays or arbitrarily-nested list/tuple/dict structures, so you can store your parameters however you’d like.
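
As a toy illustration of this API (added here for reference; the reproduction of our numpy SGD follows below), we can minimize f(x) = sum(x^2):


In [0]:
# Toy use of the (init_fun, update_fun, get_params) API on f(x) = sum(x^2).
f = lambda x: np.sum(x ** 2)
opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)
opt_state = opt_init(np.array([5.0, -3.0]))
for i in range(100):
    g = grad(f)(get_params(opt_state))
    opt_state = opt_update(i, g, opt_state)
print(get_params(opt_state))  # close to [0., 0.]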

Below we show how to reproduce our numpy code using this library.


In [0]:
# Version that uses JAX optimization library


#@jit
def sgd_jax(params, loss_fn, get_batches, max_epochs, opt_init, opt_update, get_params):
    loss_history = []
    opt_state = opt_init(params)
    
    #@jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        g = grad(loss_fn)(params, batch)
        return opt_update(i, g, opt_state) 
    
    print_every = max(1, int(0.1*max_epochs))
    total_steps = 0
    for epoch in range(max_epochs):
        epoch_loss = 0.0
        for batch_dict in get_batches():
            X, y = batch_dict["X"], batch_dict["y"]
            batch = (X, y)
            total_steps += 1
            opt_state = update(total_steps, opt_state, batch)
        params = get_params(opt_state)
        train_loss = onp.float(loss_fn(params, batch))  # NLL on the last batch of the epoch
        loss_history.append(train_loss)
        if epoch % print_every == 0:
            print('Epoch {}, train NLL {}'.format(epoch, train_loss))
    return params, loss_history

In [0]:
# Quick sanity check: evaluate the loss and gradient on a single batch.
b = list(batcher())
X, y = b[0]["X"], b[0]["y"]
X.shape
batch = (X, y)
params= w_init
onp.float(NLL(params, batch))
g = grad(NLL)(params, batch)

In [0]:
# JAX with constant LR should match our minimal version of SGD


schedule = optimizers.constant(step_size=lr)
opt_init, opt_update, get_params = optimizers.sgd(step_size=schedule)

w_mle_sgd2, history = sgd_jax(w_init, NLL, batcher, max_epochs, 
                              opt_init, opt_update, get_params)
print(w_mle_sgd2)
print(history)


Epoch 0, train NLL 0.3694833219051361
Epoch 1, train NLL 0.3485594689846039
Epoch 2, train NLL 0.33153384923934937
Epoch 3, train NLL 0.31704843044281006
Epoch 4, train NLL 0.3043736517429352
[-0.399 -0.919  0.311  2.174]
[0.3694833219051361, 0.3485594689846039, 0.33153384923934937, 0.31704843044281006, 0.3043736517429352]

In [0]: