In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This colab demonstrates how to:

1. Define the BiT (Big Transfer) ResNet-V2 architecture in JAX/Flax and load the publicly released TensorFlow weights into it.
2. Evaluate a BiT-M-R50x1 model that was fine-tuned on CIFAR-10.
3. Fine-tune the ImageNet-21k pre-trained BiT-M-R50x1 model on CIFAR-100 using only a few examples per class.

This is a good way to get an understanding of the models or to quickly try things out. However, for longer training runs we recommend using the command-line scripts at http://github.com/google-research/big_transfer
In [28]:
!pip install flax
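# Note: this notebook uses the pre-Linen flax.nn and flax.optim interfaces,
# which have since been removed from Flax, so an older Flax release is needed
# to run it as written.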
In [0]:
import io
import re
from functools import partial
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import flax
import flax.nn as nn
import flax.optim as optim
import flax.jax_utils as flax_utils
# Assert that a GPU is available.
assert jax.devices()[0].platform == 'gpu'
import tensorflow as tf
import tensorflow_datasets as tfds
In [0]:
def fixed_padding(x, kernel_size):
  pad_total = kernel_size - 1
  pad_beg = pad_total // 2
  pad_end = pad_total - pad_beg
  x = jax.lax.pad(x, 0.0,
                  ((0, 0, 0),
                   (pad_beg, pad_end, 0), (pad_beg, pad_end, 0),
                   (0, 0, 0)))
  return x
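
# Quick shape check (added for illustration, not part of the original notebook):
# for a k x k kernel, fixed_padding adds k - 1 zero pixels split across each
# spatial dimension, e.g. an 8x8 feature map padded for a 3x3 kernel becomes 10x10.
assert fixed_padding(jnp.zeros((1, 8, 8, 3)), 3).shape == (1, 10, 10, 3)
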
def standardize(x, axis, eps):
  x = x - jnp.mean(x, axis=axis, keepdims=True)
  x = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=axis, keepdims=True) + eps)
  return x
class GroupNorm(nn.Module):
  """Group normalization (arxiv.org/abs/1803.08494)."""

  def apply(self, x, num_groups=32):
    input_shape = x.shape
    group_shape = x.shape[:-1] + (num_groups, x.shape[-1] // num_groups)

    x = x.reshape(group_shape)
    # Standardize along spatial and group dimensions.
    x = standardize(x, axis=[1, 2, 4], eps=1e-5)
    x = x.reshape(input_shape)

    bias_scale_shape = tuple([1, 1, 1] + [input_shape[-1]])
    x = x * self.param('scale', bias_scale_shape, nn.initializers.ones)
    x = x + self.param('bias', bias_scale_shape, nn.initializers.zeros)
    return x
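
# Shape note (added): with num_groups=32, a (N, H, W, 256) activation is reshaped
# to (N, H, W, 32, 8) and standardized over the spatial axes and the 8 channels
# within each group, before the learned per-channel scale and bias are applied.
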
class StdConv(nn.Conv):
  def param(self, name, shape, initializer):
    param = super().param(name, shape, initializer)
    if name == 'kernel':
      param = standardize(param, axis=[0, 1, 2], eps=1e-10)
    return param
class RootBlock(nn.Module):
  def apply(self, x, width):
    x = fixed_padding(x, 7)
    x = StdConv(x, width, (7, 7), (2, 2),
                padding="VALID",
                bias=False,
                name="conv_root")

    x = fixed_padding(x, 3)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="VALID")
    return x
class ResidualUnit(nn.Module):
  """Bottleneck ResNet block."""

  def apply(self, x, nout, strides=(1, 1)):
    x_shortcut = x
    needs_projection = x.shape[-1] != nout * 4 or strides != (1, 1)

    group_norm = GroupNorm
    conv = StdConv.partial(bias=False)

    x = group_norm(x, name="gn1")
    x = nn.relu(x)
    if needs_projection:
      x_shortcut = conv(x, nout * 4, (1, 1), strides, name="conv_proj")
    x = conv(x, nout, (1, 1), name="conv1")

    x = group_norm(x, name="gn2")
    x = nn.relu(x)
    x = fixed_padding(x, 3)
    x = conv(x, nout, (3, 3), strides, name="conv2", padding='VALID')

    x = group_norm(x, name="gn3")
    x = nn.relu(x)
    x = conv(x, nout * 4, (1, 1), name="conv3")
    return x + x_shortcut
class ResidualBlock(nn.Module):
  def apply(self, x, block_size, nout, first_stride):
    x = ResidualUnit(
        x, nout, strides=first_stride,
        name="unit01")
    for i in range(1, block_size):
      x = ResidualUnit(
          x, nout, strides=(1, 1),
          name=f"unit{i+1:02d}")
    return x
class ResNet(nn.Module):
  """ResNetV2."""

  def apply(self, x, num_classes=1000,
            width_factor=1, num_layers=50):
    block_sizes = _block_sizes[num_layers]
    width = 64 * width_factor

    root_block = RootBlock.partial(width=width)
    x = root_block(x, name='root_block')

    # Blocks
    for i, block_size in enumerate(block_sizes):
      x = ResidualBlock(x, block_size, width * 2 ** i,
                        first_stride=(1, 1) if i == 0 else (2, 2),
                        name=f"block{i + 1}")

    # Pre-head
    x = GroupNorm(x, name='norm-pre-head')
    x = nn.relu(x)
    x = jnp.mean(x, axis=(1, 2))

    # Head
    x = nn.Dense(x, num_classes, name="conv_head",
                 kernel_init=nn.initializers.zeros)
    return x.astype(jnp.float32)


_block_sizes = {
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
}
def transform_params(params, params_tf, num_classes, init_head=False):
  # BiT and JAX models have different naming conventions, so we need to
  # properly map TF weights to JAX weights.
  params['root_block']['conv_root']['kernel'] = (
      params_tf['resnet/root_block/standardized_conv2d/kernel'])

  for block in ['block1', 'block2', 'block3', 'block4']:
    units = set([re.findall(r'unit\d+', p)[0] for p in params_tf.keys()
                 if p.find(block) >= 0])
    for unit in units:
      for i, group in enumerate(['a', 'b', 'c']):
        params[block][unit][f'conv{i+1}']['kernel'] = (
            params_tf[f'resnet/{block}/{unit}/{group}/'
                      'standardized_conv2d/kernel'])
        params[block][unit][f'gn{i+1}']['bias'] = (
            params_tf[f'resnet/{block}/{unit}/{group}/'
                      'group_norm/beta'][None, None, None])
        params[block][unit][f'gn{i+1}']['scale'] = (
            params_tf[f'resnet/{block}/{unit}/{group}/'
                      'group_norm/gamma'][None, None, None])

      projs = [p for p in params_tf.keys()
               if p.find(f'{block}/{unit}/a/proj') >= 0]
      assert len(projs) <= 1
      if projs:
        params[block][unit]['conv_proj']['kernel'] = params_tf[projs[0]]

  params['norm-pre-head']['bias'] = (
      params_tf['resnet/group_norm/beta'][None, None, None])
  params['norm-pre-head']['scale'] = (
      params_tf['resnet/group_norm/gamma'][None, None, None])

  if init_head:
    params['conv_head']['kernel'] = params_tf['resnet/head/conv2d/kernel'][0, 0]
    params['conv_head']['bias'] = params_tf['resnet/head/conv2d/bias']
  else:
    params['conv_head']['kernel'] = np.zeros(
        (params['conv_head']['kernel'].shape[0], num_classes), dtype=np.float32)
    params['conv_head']['bias'] = np.zeros(num_classes, dtype=np.float32)
In [0]:
with tf.io.gfile.GFile('gs://bit_models/BiT-M-R50x1-CIFAR10.npz', 'rb') as f:
  params_tf = np.load(f)
  params_tf = dict(zip(params_tf.keys(), params_tf.values()))

for k in params_tf:
  params_tf[k] = jnp.array(params_tf[k])

ResNet_cifar10 = ResNet.partial(num_classes=10)

def resnet_fn(params, images):
  return ResNet_cifar10.call(params, images)

resnet_init = ResNet_cifar10.init_by_shape
_, params = resnet_init(jax.random.PRNGKey(0), [([1, 224, 224, 3], jnp.float32)])

transform_params(params, params_tf, 10, init_head=True)
In [0]:
data_builder = tfds.builder('cifar10')
data_builder.download_and_prepare()
def _pp(data):
  im = data['image']
  im = tf.image.resize(im, [128, 128])
  im = (im - 127.5) / 127.5
  data['image'] = im
  return {'image': data['image'], 'label': data['label']}
data = data_builder.as_dataset(split='test')
data = data.map(_pp)
data = data.batch(100)
data_iter = data.as_numpy_iterator()
In [33]:
correct, n = 0, 0
for batch in data_iter:
  preds = resnet_fn(params, batch['image'])
  correct += np.sum(np.argmax(preds, axis=1) == batch['label'])
  n += len(preds)
print(f"CIFAR-10 accuracy of BiT-M-R50x1: {correct / n:0.3%}")
In [0]:
data_builder = tfds.builder('cifar100')
data_builder.download_and_prepare()
def get_data(split, repeats, batch_size, images_per_class, shuffle_buffer):
  data = data_builder.as_dataset(split=split)

  if split == 'train':
    # Materialize the full training split and keep `images_per_class`
    # randomly chosen examples for each of the 100 classes.
    data = data.batch(50000)
    data = data.as_numpy_iterator().next()
    np.random.seed(0)
    indices = [idx
               for cls in range(100)
               for idx in np.random.choice(np.where(data['label'] == cls)[0],
                                           images_per_class,
                                           replace=False)]
    data = {'image': data['image'][indices],
            'label': data['label'][indices]}
    data = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(data['image']),
                                tf.data.Dataset.from_tensor_slices(data['label'])))
    data = data.map(lambda x, y: {'image': x, 'label': y})
  else:
    data = data.map(lambda d: {'image': d['image'], 'label': d['label']})

  def _pp(data):
    im = data['image']
    if split == 'train':
      im = tf.image.resize(im, [160, 160])
      im = tf.image.random_crop(im, [128, 128, 3])
      # Random horizontal flip for augmentation (a deterministic
      # flip_left_right would simply mirror every image).
      im = tf.image.random_flip_left_right(im)
    else:
      im = tf.image.resize(im, [128, 128])
    im = (im - 127.5) / 127.5
    data['image'] = im
    data['label'] = tf.one_hot(data['label'], 100)
    return {'image': data['image'], 'label': data['label']}

  data = data.repeat(repeats)
  data = data.shuffle(shuffle_buffer)
  data = data.map(_pp)
  return data.batch(batch_size)

data_train = get_data(split='train', repeats=None, images_per_class=5,
                      batch_size=64, shuffle_buffer=500)
data_test = get_data(split='test', repeats=1, images_per_class=None,
                     batch_size=250, shuffle_buffer=1)
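
# Optional sanity check (added, not part of the original notebook): with the
# settings above, a training batch should contain 64 images of shape 128x128x3
# and one-hot labels over 100 classes.
_batch = next(data_train.as_numpy_iterator())
print(_batch['image'].shape, _batch['label'].shape)  # (64, 128, 128, 3) (64, 100)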
In [0]:
@jax.jit
def resnet_fn(params, images):
  return ResNet.partial(num_classes=100).call(params, images)

def cross_entropy_loss(*, logits, labels):
  logp = jax.nn.log_softmax(logits)
  return -jnp.mean(jnp.sum(logp * labels, axis=1))

def loss_fn(params, images, labels):
  logits = resnet_fn(params, images)
  return cross_entropy_loss(logits=logits, labels=labels)

@jax.jit
def update_fn(opt, lr, images, labels):
  l, g = jax.value_and_grad(loss_fn)(opt.target, images, labels)
  opt = opt.apply_gradient(g, learning_rate=lr)
  return opt, l

with tf.io.gfile.GFile('gs://bit_models/BiT-M-R50x1.npz', 'rb') as f:
  params_tf = np.load(f)
  params_tf = dict(zip(params_tf.keys(), params_tf.values()))

resnet_init = ResNet.partial(num_classes=100).init_by_shape
_, params = resnet_init(jax.random.PRNGKey(0), [([1, 224, 224, 3], jnp.float32)])

transform_params(params, params_tf, 100, init_head=False)
In [0]:
def get_lr(step):
  lr = 0.003
  if step < 100:
    return lr * (step / 100)
  else:
    for s in [200, 300, 400]:
      if s < step:
        lr /= 10
    return lr
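
# Quick check of the schedule above (added for illustration): linear warm-up over
# the first 100 steps, then the rate is divided by 10 after steps 200, 300 and 400.
for _s in [0, 50, 100, 250, 350, 450]:
  print(f"step {_s:3d}: lr = {get_lr(_s):.2e}")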
In [55]:
opt = optim.Momentum(beta=0.9).create(params)

for step, batch in zip(range(500), data_train.as_numpy_iterator()):
  opt, loss_value = update_fn(
      opt, get_lr(step), batch["image"], batch["label"])
  if opt.state.step % 100 == 0:
    acc = np.mean([c for test_batch in data_test.as_numpy_iterator()
                   for c in (np.argmax(test_batch['label'], axis=1) ==
                             np.argmax(resnet_fn(opt.target, test_batch['image']), axis=1))])
    print(f"Step: {opt.state.step}, Test accuracy: {acc:0.3%}")
In [0]: