In [0]:
import sklearn
import scipy
import scipy.optimize
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import itertools
import time
from functools import partial
import os

import numpy as np
from scipy.special import logsumexp
np.set_printoptions(precision=3)

# We make some wrappers around random number generation
# so it works even if we switch from numpy to JAX
import numpy as onp  # original numpy

def set_seed(seed):
    onp.random.seed(seed)

def randn(*args):
    return onp.random.randn(*args)

def randperm(args):
    return onp.random.permutation(args)
In [0]:
import torch
import torchvision
print("torch version {}".format(torch.__version__))
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    print("current device {}".format(torch.cuda.current_device()))
else:
    print("Torch cannot find GPU")

def set_seed(seed):
    onp.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
#torch.backends.cudnn.benchmark = True
In [0]:
# Tensorflow 2.0
try:
    # %tensorflow_version only exists in Colab.
    %tensorflow_version 2.x
except Exception:
    pass
import tensorflow as tf
from tensorflow import keras
print("tf version {}".format(tf.__version__))
if tf.test.is_gpu_available():
    print(tf.test.gpu_device_name())
else:
    print("TF cannot find GPU")
In [0]:
# JAX (https://github.com/google/jax)
!pip install --upgrade -q https://storage.googleapis.com/jax-releases/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-$(pip search jaxlib | grep -oP '[0-9\.]+' | head -n 1)-cp36-none-linux_x86_64.whl
!pip install --upgrade -q jax
import jax
import jax.numpy as np
import numpy as onp
from jax.scipy.special import logsumexp
from jax import grad, hessian, jacfwd, jacrev, jit, vmap
from jax.experimental import optimizers
print("jax version {}".format(jax.__version__))
In this section we illustrate various AD libraries by using them to compute the gradient of the negative log likelihood (NLL) for binary logistic regression applied to the Iris dataset, and compare the results to a manual numpy implementation.
As a minor detail, we evaluate the gradient of the NLL on the test data with the parameters set to their training MLE, in order to get an interesting signal; using a random weight vector would make the dynamic range of the output harder to interpret.
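For reference, with $\mu_n = \sigma(\mathbf{w}^\top \mathbf{x}_n)$ denoting the predicted probability for example $n$, the quantities computed below are the average NLL and its gradient,
$$\mathrm{NLL}(\mathbf{w}) = -\frac{1}{N}\sum_{n=1}^N \big[y_n \log \mu_n + (1-y_n)\log(1-\mu_n)\big], \qquad \nabla_{\mathbf{w}}\,\mathrm{NLL}(\mathbf{w}) = \frac{1}{N}\sum_{n=1}^N (\mu_n - y_n)\,\mathbf{x}_n,$$
which is exactly what the NLL and NLL_grad functions below implement (using the logsumexp trick for numerical stability).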
In [0]:
# Fit the model to a dataset, so we have an "interesting" parameter vector to use.
import sklearn.datasets
from sklearn.model_selection import train_test_split

iris = sklearn.datasets.load_iris()
X = iris["data"]
y = (iris["target"] == 2).astype(int)  # 1 if Iris-Virginica, else 0
N, D = X.shape  # 150, 4
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

from sklearn.linear_model import LogisticRegression
# We set C to a large number to turn off regularization.
# We don't fit the bias term to simplify the comparison below.
log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
log_reg.fit(X_train, y_train)
w_mle_sklearn = np.ravel(log_reg.coef_)
w = w_mle_sklearn
In [0]:
## Compute gradient of loss "by hand" using numpy

def BCE_with_logits(logits, targets):
    # Binary cross entropy from logits, computed with logsumexp for numerical stability.
    N = logits.shape[0]
    logits = logits.reshape(N, 1)
    logits_plus = np.hstack([np.zeros((N, 1)), logits])   # prepend a column of zeros (e^0 = 1)
    logits_minus = np.hstack([np.zeros((N, 1)), -logits])
    logp1 = -logsumexp(logits_minus, axis=1)  # log p(y=1) = -log(1 + e^{-logit})
    logp0 = -logsumexp(logits_plus, axis=1)   # log p(y=0) = -log(1 + e^{logit})
    logprobs = logp1 * targets + logp0 * (1 - targets)
    return -np.sum(logprobs) / N

# Compute using numpy
def sigmoid(x):
    # Numerically stable sigmoid, using sigmoid(x) = 0.5*(tanh(x/2) + 1)
    return 0.5 * (np.tanh(x / 2.) + 1)

def predict_logit(weights, inputs):
    return np.dot(inputs, weights)  # Already vectorized

def predict_prob(weights, inputs):
    return sigmoid(predict_logit(weights, inputs))

def NLL(weights, batch):
    X, y = batch
    logits = predict_logit(weights, X)
    return BCE_with_logits(logits, y)

def NLL_grad(weights, batch):
    X, y = batch
    N = X.shape[0]
    mu = predict_prob(weights, X)
    g = np.sum(np.dot(np.diag(mu - y), X), axis=0) / N
    return g

y_pred = predict_prob(w, X_test)
loss = NLL(w, (X_test, y_test))
grad_np = NLL_grad(w, (X_test, y_test))
print("params {}".format(w))
#print("pred {}".format(y_pred))
print("loss {}".format(loss))
print("grad {}".format(grad_np))
Below we use JAX to compute the gradient of the NLL for binary logistic regression. For some examples of using JAX to compute gradients, Jacobians, and Hessians of simple linear and quadratic functions, see this notebook. More details on JAX's autodiff can be found in the official autodiff cookbook.
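As a quick warm-up, here is a minimal sketch of the JAX autodiff API on a toy quadratic $f(x) = x^\top x$:
In [0]:
# Minimal sketch of JAX autodiff on a toy quadratic f(x) = x^T x.
import jax
import jax.numpy as jnp

def f(x):
    return jnp.dot(x, x)

x = jnp.arange(3.0)
print(jax.grad(f)(x))                # 2x = [0. 2. 4.]
print(jax.hessian(f)(x))             # 2I
print(jax.jacfwd(jax.jacrev(f))(x))  # same Hessian, via forward-over-reverse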
In [0]:
grad_jax = grad(NLL)(w, (X_test, y_test))
print("grad {}".format(grad_jax))
assert np.allclose(grad_np, grad_jax)
In [0]:
w_tf = tf.Variable(np.reshape(w, (D, 1)))
x_test_tf = tf.convert_to_tensor(X_test, dtype=np.float64)
y_test_tf = tf.convert_to_tensor(np.reshape(y_test, (-1, 1)), dtype=np.float64)
with tf.GradientTape() as tape:
    logits = tf.linalg.matmul(x_test_tf, w_tf)
    y_pred = tf.math.sigmoid(logits)
    loss_batch = tf.nn.sigmoid_cross_entropy_with_logits(y_test_tf, logits)
    loss_tf = tf.reduce_mean(loss_batch, axis=0)
grad_tf = tape.gradient(loss_tf, [w_tf])
grad_tf = grad_tf[0][:, 0].numpy()
assert np.allclose(grad_np, grad_tf)
print("params {}".format(w_tf))
#print("pred {}".format(y_pred))
print("loss {}".format(loss_tf))
print("grad {}".format(grad_tf))
In [0]:
w_torch = torch.Tensor(np.reshape(w, [D, 1])).to(device)
w_torch.requires_grad_()
x_test_tensor = torch.Tensor(X_test).to(device)
y_test_tensor = torch.Tensor(y_test).to(device)
y_pred = torch.sigmoid(torch.matmul(x_test_tensor, w_torch))[:,0]
criterion = torch.nn.BCELoss(reduction='mean')
loss_torch = criterion(y_pred, y_test_tensor)
loss_torch.backward()
grad_torch = w_torch.grad[:,0].cpu().numpy()
assert np.allclose(grad_np, grad_torch)
print("params {}".format(w_torch))
#print("pred {}".format(y_pred))
print("loss {}".format(loss_torch))
print("grad {}".format(grad_torch))
The "gold standard" of optimization is second-order methods, that leverage Hessian information. Since the Hessian has O(D^2) parameters, such methods do not scale to high-dimensional problems. However, we can sometimes approximate the Hessian using low-rank or diagonal approximations. Below we illustrate the low-rank BFGS method, and the limited-memory version of BFGS, that uses O(D H) space and O(D^2) time per step, where H is the history length.
In general, second-order methods also require exact (rather than noisy) gradients. In the context of ML, this means they are "full batch" methods, since computing the exact gradient requires evaluating the loss on all the datapoints. However, for small data problems, this is feasible (and advisable).
Below we illustrate how to use LBFGS as implemented in various libraries. Other second-order optimizers have a similar API. We use the same binary logistic regression problem as above.
In [0]:
# Repeat relevant code from AD section above, for convenience.

# Dataset
import sklearn.datasets
from sklearn.model_selection import train_test_split

iris = sklearn.datasets.load_iris()
X = iris["data"]
y = (iris["target"] == 2).astype(int)  # 1 if Iris-Virginica, else 0
N, D = X.shape  # 150, 4
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

# Sklearn estimate
from sklearn.linear_model import LogisticRegression
log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
log_reg.fit(X_train, y_train)
w_mle_sklearn = np.ravel(log_reg.coef_)
w = w_mle_sklearn

# Define model and binary cross entropy loss
def BCE_with_logits(logits, targets):
    N = logits.shape[0]
    logits = logits.reshape(N, 1)
    logits_plus = np.hstack([np.zeros((N, 1)), logits])   # e^0=1
    logits_minus = np.hstack([np.zeros((N, 1)), -logits])
    logp1 = -logsumexp(logits_minus, axis=1)
    logp0 = -logsumexp(logits_plus, axis=1)
    logprobs = logp1 * targets + logp0 * (1 - targets)
    return -np.sum(logprobs) / N

def sigmoid(x):
    return 0.5 * (np.tanh(x / 2.) + 1)

def predict_logit(weights, inputs):
    return np.dot(inputs, weights)  # Already vectorized

def predict_prob(weights, inputs):
    return sigmoid(predict_logit(weights, inputs))

def NLL(weights, batch):
    X, y = batch
    logits = predict_logit(weights, X)
    return BCE_with_logits(logits, y)
We show how to use the BFGS and L-BFGS implementations from scipy.optimize.
In [0]:
import scipy.optimize

# We manually compute gradients, but could use Jax instead
def NLL_grad(weights, batch):
    X, y = batch
    N = X.shape[0]
    mu = predict_prob(weights, X)
    g = np.sum(np.dot(np.diag(mu - y), X), axis=0) / N
    return g

def training_loss(w):
    return NLL(w, (X_train, y_train))

def training_grad(w):
    return NLL_grad(w, (X_train, y_train))

set_seed(0)
w_init = randn(D)

options = {'disp': None, 'maxfun': 1000, 'maxiter': 1000}
method = 'BFGS'
w_mle_scipy = scipy.optimize.minimize(
    training_loss, w_init, jac=training_grad,
    method=method, options=options).x

print("parameters from sklearn {}".format(w_mle_sklearn))
print("parameters from scipy-bfgs {}".format(w_mle_scipy))
In [0]:
# The limited-memory version requires that we work with 64-bit floats,
# since it is implemented in Fortran.
def training_loss2(w):
    l = NLL(w, (X_train, y_train))
    return onp.float64(l)

def training_grad2(w):
    g = NLL_grad(w, (X_train, y_train))
    return onp.asarray(g, dtype=onp.float64)

set_seed(0)
w_init = randn(D)

memory = 10
options = {'disp': None, 'maxcor': memory, 'maxfun': 1000, 'maxiter': 1000}
# The code also handles bound constraints, hence the name
method = 'L-BFGS-B'
w_mle_scipy = scipy.optimize.minimize(
    training_loss2, w_init, jac=training_grad2,
    method=method, options=options).x

print("parameters from sklearn {}".format(w_mle_sklearn))
print("parameters from scipy-lbfgs {}".format(w_mle_scipy))
We show how to use the version from torch.optim.LBFGS.
In [0]:
# Put data into PyTorch format.
import torch
from torch.utils.data import DataLoader, TensorDataset
N, D = X_train.shape
x_train_tensor = torch.Tensor(X_train)
y_train_tensor = torch.Tensor(y_train)
data_set = TensorDataset(x_train_tensor, y_train_tensor)
In [0]:
# Define model and loss.
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(D, 1, bias=False)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

set_seed(0)
model = Model()
criterion = torch.nn.BCELoss(reduction='mean')
In [0]:
optimizer = torch.optim.LBFGS(model.parameters(), history_size=10)

def closure():
    optimizer.zero_grad()
    y_pred = model(x_train_tensor)[:, 0]  # (N,) to match the shape of y_train_tensor
    loss = criterion(y_pred, y_train_tensor)
    #print('loss:', loss.item())
    loss.backward()
    return loss

max_iter = 10
for i in range(max_iter):
    loss = optimizer.step(closure)

params = list(model.parameters())
w_torch_bfgs = params[0][0].detach().numpy()  # (D,) vector
print("parameters from sklearn {}".format(w_mle_sklearn))
print("parameters from torch-bfgs {}".format(w_torch_bfgs))
There is also a version of L-BFGS for TensorFlow.
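For example, tfp.optimizer.lbfgs_minimize in the TensorFlow Probability package (assumed to be installed; it ships with Colab) can be used on our unregularized logistic regression problem. The cell below is a minimal sketch of this approach:
In [0]:
# Minimal sketch: L-BFGS via TensorFlow Probability.
import tensorflow_probability as tfp

x_train_tf = tf.constant(X_train, dtype=tf.float64)
y_train_tf = tf.constant(y_train, dtype=tf.float64)

def nll_tf(w):
    logits = tf.linalg.matvec(x_train_tf, w)
    return tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=y_train_tf, logits=logits))

def nll_and_grad(w):
    # lbfgs_minimize expects a function returning (value, gradient)
    return tfp.math.value_and_gradient(nll_tf, w)

results = tfp.optimizer.lbfgs_minimize(
    nll_and_grad,
    initial_position=tf.zeros(D, dtype=tf.float64),
    num_correction_pairs=10,
    max_iterations=100)

print("parameters from sklearn {}".format(w_mle_sklearn))
print("parameters from tfp-lbfgs {}".format(results.position.numpy()))
Next we turn to stochastic first-order methods. We first define a minibatch generator, using tensorflow_datasets to iterate over the data as dictionaries of numpy arrays.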
In [0]:
import tensorflow_datasets as tfds

def make_batcher(batch_size, X, y):
    def get_batches():
        # Convert numpy arrays to tfds
        ds = tf.data.Dataset.from_tensor_slices({"X": X, "y": y})
        ds = ds.batch(batch_size)
        # convert tfds into an iterable of dicts of NumPy arrays
        return tfds.as_numpy(ds)
    return get_batches

batcher = make_batcher(2, X_train, y_train)
for epoch in range(2):
    print('epoch {}'.format(epoch))
    for batch in batcher():
        x, y = batch["X"], batch["y"]
        #print(x.shape)
In [0]:
def sgd(params, loss_fn, grad_loss_fn, get_batches_as_dict, max_epochs, lr):
    print_every = max(1, int(0.1 * max_epochs))
    for epoch in range(max_epochs):
        epoch_loss = 0.0
        for batch_dict in get_batches_as_dict():
            x, y = batch_dict["X"], batch_dict["y"]
            batch = (x, y)
            batch_grad = grad_loss_fn(params, batch)
            params = params - lr * batch_grad
            batch_loss = loss_fn(params, batch)  # Average loss within this batch
            epoch_loss += batch_loss
        if epoch % print_every == 0:
            print('Epoch {}, Loss {}'.format(epoch, epoch_loss))
    return params
In [0]:
set_seed(0)
D = X_train.shape[1]
w_init = onp.random.randn(D)

def training_loss2(w):
    l = NLL(w, (X_train, y_train))
    return onp.float64(l)

def training_grad2(w):
    g = NLL_grad(w, (X_train, y_train))
    return onp.asarray(g, dtype=onp.float64)

max_epochs = 5
lr = 0.1
batch_size = 10
batcher = make_batcher(batch_size, X_train, y_train)
w_mle_sgd = sgd(w_init, NLL, NLL_grad, batcher, max_epochs, lr)
print(w_mle_sgd)
JAX has a small optimization library focused on stochastic first-order optimizers. Every optimizer is modeled as an (init_fun, update_fun, get_params) triple of functions. The init_fun is used to initialize the optimizer state, which could include things like momentum variables, and the update_fun accepts a gradient and an optimizer state to produce a new optimizer state. The get_params function extracts the current iterate (i.e. the current parameters) from the optimizer state. The parameters being optimized can be ndarrays or arbitrarily-nested list/tuple/dict structures, so you can store your parameters however you'd like.
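To make the API concrete, here is a minimal sketch of the three functions in action, using plain SGD on a toy quadratic loss:
In [0]:
# Minimal sketch of the (init_fun, update_fun, get_params) API on a toy quadratic.
import jax.numpy as jnp
from jax import grad
from jax.experimental import optimizers

toy_loss = lambda p: jnp.sum(p ** 2)

opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)
opt_state = opt_init(jnp.ones(3))              # optimizer state wraps the initial params
for i in range(100):
    g = grad(toy_loss)(get_params(opt_state))  # gradient at the current iterate
    opt_state = opt_update(i, g, opt_state)    # one update step
print(get_params(opt_state))                   # close to zero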
Below we show how to reproduce our numpy code using this library.
In [0]:
# Version that uses JAX optimization library

#@jit
def sgd_jax(params, loss_fn, get_batches, max_epochs, opt_init, opt_update, get_params):
    loss_history = []
    opt_state = opt_init(params)

    #@jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        g = grad(loss_fn)(params, batch)
        return opt_update(i, g, opt_state)

    print_every = max(1, int(0.1 * max_epochs))
    total_steps = 0
    for epoch in range(max_epochs):
        epoch_loss = 0.0
        for batch_dict in get_batches():
            X, y = batch_dict["X"], batch_dict["y"]
            batch = (X, y)
            total_steps += 1
            opt_state = update(total_steps, opt_state, batch)
        params = get_params(opt_state)
        train_loss = float(loss_fn(params, batch))  # loss on the last minibatch of the epoch
        loss_history.append(train_loss)
        if epoch % print_every == 0:
            print('Epoch {}, train NLL {}'.format(epoch, train_loss))
    return params, loss_history
In [0]:
b = list(batcher())
X, y = b[0]["X"], b[0]["y"]
print(X.shape)
batch = (X, y)
params = w_init
print(float(NLL(params, batch)))
g = grad(NLL)(params, batch)
In [0]:
# JAX with constant LR should match our minimal version of SGD
schedule = optimizers.constant(step_size=lr)
opt_init, opt_update, get_params = optimizers.sgd(step_size=schedule)
w_mle_sgd2, history = sgd_jax(w_init, NLL, batcher, max_epochs,
opt_init, opt_update, get_params)
print(w_mle_sgd2)
print(history)
In [0]: