In [ ]:
##### Copyright 2020 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Introduction

In this notebook, we show how to quantize a model using AutoQKeras.

As usual, let's first make sure we are using Python 3.


In [ ]:
import sys
print(sys.version)

Now, let's load some packages we will need to run AutoQKeras.


In [ ]:
import warnings
warnings.filterwarnings("ignore")

import json
import pprint
import numpy as np
import six
import tempfile
import tensorflow.compat.v2 as tf
# V2 Behavior is necessary to use TF2 APIs before TF2 is default TF version internally.
tf.enable_v2_behavior()
from tensorflow.keras.optimizers import *

from qkeras.autoqkeras import *
from qkeras import *
from qkeras.utils import model_quantize
from qkeras.qtools import run_qtools
from qkeras.qtools import settings as qtools_settings

from tensorflow.keras.utils import to_categorical
import tensorflow_datasets as tfds

print("using tensorflow", tf.__version__)

Let's define get_data and get_model here, as you may not have standalone access to the examples directory inside AutoQKeras.


In [ ]:
def get_data(dataset_name, fast=False):
  """Returns dataset from tfds."""
  ds_train = tfds.load(name=dataset_name, split="train", batch_size=-1)
  ds_test = tfds.load(name=dataset_name, split="test", batch_size=-1)

  dataset = tfds.as_numpy(ds_train)
  x_train, y_train = dataset["image"].astype(np.float32), dataset["label"]

  dataset = tfds.as_numpy(ds_test)
  x_test, y_test = dataset["image"].astype(np.float32), dataset["label"]

  if len(x_train.shape) == 3:
    x_train = x_train.reshape(x_train.shape + (1,))
    x_test = x_test.reshape(x_test.shape + (1,))

  x_train /= 256.0
  x_test /= 256.0

  x_mean = np.mean(x_train, axis=0)

  x_train -= x_mean
  x_test -= x_mean

  nb_classes = np.max(y_train) + 1
  y_train = to_categorical(y_train, nb_classes)
  y_test = to_categorical(y_test, nb_classes)

  print(x_train.shape[0], "train samples")
  print(x_test.shape[0], "test samples")
  return (x_train, y_train), (x_test, y_test)

In [ ]:
from tensorflow.keras.initializers import *
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import *

class ConvBlockNetwork(object):
  """Creates Convolutional block type of network."""

  def __init__(
      self,
      shape,
      nb_classes,
      kernel_size,
      filters,
      dropout_rate=0.0,
      with_maxpooling=True,
      with_batchnorm=True,
      kernel_initializer="he_normal",
      bias_initializer="zeros",
      use_separable=False,
      use_xnornet_trick=False,
      all_conv=False
  ):
    """Creates class.

    Args:
      shape: shape of inputs.
      nb_classes: number of output classes.
      kernel_size: kernel_size of network.
      filters: sizes of filters (if entry is a list, we create a block).
      dropout_rate: dropout rate if > 0.
      with_maxpooling: if true, use maxpooling.
      with_batchnorm: with BatchNormalization.
      kernel_initializer: kernel_initializer.
      bias_initializer: bias and beta initializer.
      use_separable: if "dsp", do conv's 1x3 + 3x1. If "mobilenet",
        use MobileNet separable convolution. If False or "none", perform single
        conv layer.
      use_xnornet_trick: use bn+act after max pool to enable binary
        to avoid saturation to largest value.
      all_conv: if true, implements all convolutional network.
    """
    self.shape = shape
    self.nb_classes = nb_classes
    self.kernel_size = kernel_size
    self.filters = filters
    self.dropout_rate = dropout_rate
    self.with_maxpooling = with_maxpooling
    self.with_batchnorm = with_batchnorm
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.use_separable = use_separable
    self.use_xnornet_trick = use_xnornet_trick
    self.all_conv = all_conv

  def build(self):
    """Builds model."""
    x = x_in = Input(self.shape, name="input")
    for i in range(len(self.filters)):
      if len(self.filters) > 1:
        name_suffix_list = [str(i)]
      else:
        name_suffix_list = []
      if not isinstance(self.filters[i], list):
        filters = [self.filters[i]]
      else:
        filters = self.filters[i]
      for j in range(len(filters)):
        if len(filters) > 1:
          name_suffix = "_".join(name_suffix_list + [str(j)])
        else:
          name_suffix = "_".join(name_suffix_list)
        if self.use_separable == "dsp":
          kernels = [(1, self.kernel_size), (self.kernel_size, 1)]
        else:
          kernels = [(self.kernel_size, self.kernel_size)]
        for k, kernel in enumerate(kernels):
          strides = 1
          if (
              not self.with_maxpooling and j == len(filters)-1 and
              k == len(kernels)-1
          ):
            strides = 2
          if self.use_separable == "dsp":
            kernel_suffix = (
                "".join([str(k) for k in kernel]) + "_" + name_suffix)
          elif self.use_separable == "mobilenet":
            depth_suffix = (
                "".join([str(k) for k in kernel]) + "_" + name_suffix)
            kernel_suffix = "11_" + name_suffix
          else:
            kernel_suffix = name_suffix
          if self.use_separable == "mobilenet":
            x = DepthwiseConv2D(
                kernel,
                padding="same", strides=strides,
                use_bias=False,
                name="conv2d_dw_" + depth_suffix)(x)
            if self.with_batchnorm:
              x = BatchNormalization(name="conv2d_dw_bn_" + depth_suffix)(x)
            x = Activation("relu", name="conv2d_dw_act_" + depth_suffix)(x)
            kernel = (1, 1)
            strides = 1
          x = Conv2D(
              filters[j], kernel,
              strides=strides, use_bias=not self.with_batchnorm,
              padding="same",
              kernel_initializer=self.kernel_initializer,
              bias_initializer=self.bias_initializer,
              name="conv2d_" + kernel_suffix)(x)
          if not (
              self.with_maxpooling and self.use_xnornet_trick and
              j == len(filters)-1 and k == len(kernels)-1
          ):
            if self.with_batchnorm:
              x = BatchNormalization(
                  beta_initializer=self.bias_initializer,
                  name="bn_" + kernel_suffix)(x)
            x = Activation("relu", name="act_" + kernel_suffix)(x)
      if self.with_maxpooling:
        x = MaxPooling2D(2, 2, name="mp_" + name_suffix)(x)
        # this is a trick from XNOR-Net that enables fully binary or ternary
        # networks to follow max pooling without saturating.
        if self.use_xnornet_trick:
          x = BatchNormalization(
              beta_initializer=self.bias_initializer,
              name="mp_bn_" + name_suffix)(x)
          x = Activation("relu", name="mp_act_" + name_suffix)(x)
      if self.dropout_rate > 0:
        x = Dropout(self.dropout_rate, name="drop_" + name_suffix)(x)

    if not self.all_conv:
      x = Flatten(name="flatten")(x)
      x = Dense(
          self.nb_classes,
          kernel_initializer=self.kernel_initializer,
          bias_initializer=self.bias_initializer,
          name="dense")(x)
      x = Activation("softmax", name="softmax")(x)
    else:
      x = Conv2D(
          self.nb_classes, 1, strides=1, padding="same",
          kernel_initializer=self.kernel_initializer,
          bias_initializer=self.bias_initializer,
          name="dense")(x)
      x = Activation("softmax", name="softmax")(x)
      x = Flatten(name="flatten")(x)

    model = Model(inputs=[x_in], outputs=[x])

    return model


def get_model(dataset):
  """Returns a model for the demo of AutoQKeras."""
  if dataset == "mnist":
    model = ConvBlockNetwork(
        shape=(28, 28, 1),
        nb_classes=10,
        kernel_size=3,
        filters=[16, 32, 48, 64, 128],
        dropout_rate=0.2,
        with_maxpooling=False,
        with_batchnorm=True,
        kernel_initializer="he_uniform",
        bias_initializer="zeros",
    ).build()

  elif dataset == "fashion_mnist":
    model = ConvBlockNetwork(
        shape=(28, 28, 1),
        nb_classes=10,
        kernel_size=3,
        filters=[16, [32]*3, [64]*3],
        dropout_rate=0.2,
        with_maxpooling=True,
        with_batchnorm=True,
        use_separable="mobilenet",
        kernel_initializer="he_uniform",
        bias_initializer="zeros",
        use_xnornet_trick=True
    ).build()

  elif dataset == "cifar10":
    model = ConvBlockNetwork(
        shape=(32, 32, 3),
        nb_classes=10,
        kernel_size=3,
        filters=[16, [32]*3, [64]*3, [128]*3],
        dropout_rate=0.2,
        with_maxpooling=True,
        with_batchnorm=True,
        use_separable="mobilenet",
        kernel_initializer="he_uniform",
        bias_initializer="zeros",
        use_xnornet_trick=True
    ).build()

  elif dataset == "cifar100":
    model = ConvBlockNetwork(
        shape=(32, 32, 3),
        nb_classes=100,
        kernel_size=3,
        filters=[16, [32]*3, [64]*3, [128]*3, [256]*3],
        dropout_rate=0.2,
        with_maxpooling=True,
        with_batchnorm=True,
        use_separable="mobilenet",
        kernel_initializer="he_uniform",
        bias_initializer="zeros",
        use_xnornet_trick=True
    ).build()

  model.summary()

  return model

AutoQKeras ships with examples of how to run on mnist, fashion_mnist, cifar10 and cifar100.


In [ ]:
DATASET = "mnist"
(x_train, y_train), (x_test, y_test) = get_data(DATASET)

Before we create the model, let's see if we can perform distributed training.


In [ ]:
physical_devices = tf.config.list_physical_devices()
for d in physical_devices:
  print(d)

In [ ]:
has_tpus = np.any([d.device_type == "TPU" for d in physical_devices])

if has_tpus:
  TPU_WORKER = 'local'

  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu=TPU_WORKER, job_name='tpu_worker')
  if TPU_WORKER != 'local':
    tf.config.experimental_connect_to_cluster(resolver, protocol='grpc+loas')
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.experimental.TPUStrategy(resolver)
  print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

  cur_strategy = strategy
else:
  cur_strategy = tf.distribute.get_strategy()

Now we can create the model with the distributed strategy in place if TPUs are available. We have some test models that we can use, or you can build your own models.


In [ ]:
with cur_strategy.scope():
  model = get_model(DATASET)
  custom_objects = {}

Let's see the accuracy of the unquantized model.


In [ ]:
with cur_strategy.scope():
  optimizer = Adam(lr=0.02)
  model.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["acc"])
  model.fit(x_train, y_train, epochs=10, batch_size=2048, steps_per_epoch=29, validation_data=(x_test, y_test))

For mnist, we should get about 99% validation accuracy, and for fashion_mnist, around 86%. Now let's get a high-level estimate of the energy of this model.


In [ ]:
reference_internal = "fp32"
  reference_accumulator = "fp32"

  q = run_qtools.QTools(
      model,
      # energy calculation using a given process
      # "horowitz" refers to 45nm process published at
      # M. Horowitz, "1.1 Computing's energy problem (and what we can do about
      # it), "2014 IEEE International Solid-State Circuits Conference Digest of
      # Technical Papers (ISSCC), San Francisco, CA, 2014, pp. 10-14, 
      # doi: 10.1109/ISSCC.2014.6757323.
      process="horowitz",
      # quantizers for model input
      source_quantizers=[quantized_bits(8, 0, 1)],
      is_inference=False,
      # absolute path (including filename) of the model weights
      # in the future, we will attempt to optimize the power model
      # by using weight information, although it can be used to further
      # optimize QBatchNormalization.
      weights_path=None,
      # keras_quantizer to quantize weight/bias in un-quantized keras layers
      keras_quantizer=reference_internal,
      # keras_quantizer to quantize MAC in un-quantized keras layers
      keras_accumulator=reference_accumulator,
      # whether calculate baseline energy
      for_reference=True)
  
# caculate energy of the derived data type map.
energy_dict = q.pe(
    # whether to store parameters in dram, sram, or fixed
    weights_on_memory="sram",
    # store activations in dram or sram
    activations_on_memory="sram",
    # minimum sram size in number of bits. Let's assume a 16MB SRAM.
    min_sram_size=8*16*1024*1024,
    # whether load data from dram to sram (consider sram as a cache
    # for dram. If false, we will assume data will be already in SRAM
    rd_wr_on_io=False)

# get stats of energy distribution in each layer
energy_profile = q.extract_energy_profile(
    qtools_settings.cfg.include_energy, energy_dict)
# extract sum of energy of each layer according to the rule specified in
# qtools_settings.cfg.include_energy
total_energy = q.extract_energy_sum(
    qtools_settings.cfg.include_energy, energy_dict)

pprint.pprint(energy_profile)
print()
print("Total energy: {:.2f} uJ".format(total_energy / 1000000.0))

During the computation, we obtained a dictionary that outlines the energy per layer (energy_profile) and the total energy (total_energy). The reader should remember that energy_profile may need additional filtering, as implementations will fuse some layers. When we compute total_energy, we use the approximation that some layers will be fused. For example, a convolution layer followed by an activation layer is fused into a single layer, so the intermediate output of the convolution does not need to be stored.

You have to remember that our high-level model for energy has several assumptions:

The energy of a layer is estimated as energy(layer) = energy(input) + energy(parameters) + energy(MAC) + energy(output).

1) Reading inputs, parameters and outputs considers only compulsory accesses, i.e. the first access to the data, which is independent of the hardware architecture. If you remember the 3 C's of caches (https://courses.cs.washington.edu/courses/cse410/99au/lectures/Lecture-10-18/tsld035.htm), the other types of accesses depend on the accelerator architecture.

2) For the multiply-and-accumulate (MAC) energy estimation, we only consider the energy to compute the MAC, not any other type of energy. For example, a real accelerator has registers, glue logic and pipeline logic that affect the overall energy profile of the device.

Although this model is simple and provides an initial estimate of what to expect, it has high variance with respect to the actual energy numbers you will find in practice, especially across different architectural implementations.

We assume that the real energy Energy(layer) is a linear function of the high-level energy model, i.e. Energy(layer) = k1 * energy(layer) + k2, where k1 and k2 are constants that depend on the architecture of the accelerator. One can think of k1 as the factor that accounts for the additional storage needed to keep the model running, and k2 as the additional always-on logic required to perform the operations. If we compare the energy of two implementations with different quantizations of the same layer, say layer1 and layer2, then Energy(layer1) > Energy(layer2) holds iff energy(layer1) > energy(layer2) on the same architecture; across different architectures this is not true in general.

Despite its limitations in predicting a single energy number, this model is quite good for comparing the energy of two different models, or of different quantizations, as long as we restrict the comparison to a single architecture, and that is how we use it here.
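
To make this concrete, here is a tiny numeric sketch; the k1 and k2 values are invented purely for illustration. Within a single architecture the ordering given by the high-level estimates is preserved, while across architectures it may flip.


In [ ]:
# Illustrative sketch only: made-up k1/k2 constants for two hypothetical
# accelerator architectures, applied to two high-level energy estimates.
def arch_energy(e, k1, k2):
  # real energy modeled as a linear function of the high-level estimate
  return k1 * e + k2

e1, e2 = 100.0, 40.0  # high-level estimates for two quantizations of a layer

# Same architecture (same k1 > 0, k2): the ordering is preserved.
print(arch_energy(e1, 2.0, 50.0) > arch_energy(e2, 2.0, 50.0))   # True

# Different architectures: the ordering of the estimates tells us nothing.
print(arch_energy(e1, 0.5, 10.0) > arch_energy(e2, 3.0, 200.0))  # False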

Quantizing a Model With AutoQKeras

To quantize this model with AutoQKeras, we need to define the quantization search space for kernels, biases and activations; the forgiving factor; and the quantization strategy.

Below we define which quantizers are allowed for kernels, biases, activations and linear. Linear is a proxy that we use to capture Activation("linear"), i.e. to apply quantization without applying a non-linear operation. In some networks, we found that this trick may be necessary to better represent the quantization space.


In [ ]:
quantization_config = {
        "kernel": {
                "binary": 1,
                "stochastic_binary": 1,
                "ternary": 2,
                "stochastic_ternary": 2,
                "quantized_bits(2,1,1,alpha=1.0)": 2,
                "quantized_bits(4,0,1,alpha=1.0)": 4,
                "quantized_bits(8,0,1,alpha=1.0)": 8,
                "quantized_po2(4,1)": 4
        },
        "bias": {
                "quantized_bits(4,0,1)": 4,
                "quantized_bits(8,3,1)": 8,
                "quantized_po2(4,8)": 4
        },
        "activation": {
                "binary": 1,
                "ternary": 2,
                "quantized_relu_po2(4,4)": 4,
                "quantized_relu(3,1)": 3,
                "quantized_relu(4,2)": 4,
                "quantized_relu(8,2)": 8,
                "quantized_relu(8,4)": 8,
                "quantized_relu(16,8)": 16
        },
        "linear": {
                "binary": 1,
                "ternary": 2,
                "quantized_bits(4,1)": 4,
                "quantized_bits(8,2)": 8,
                "quantized_bits(16,10)": 16
        }
}

Now let's define how to apply quantization. In the simplest form, we specify how many bits for kernels, biases and activations by layer types. Note that the entry BatchNormalization needs to be specified here, as we only quantize layer types specified by these patterns. For example, a Flatten layer is not quantized as it does not change the data type of its inputs.


In [ ]:
limit = {
    "Dense": [8, 8, 4],
    "Conv2D": [4, 8, 4],
    "DepthwiseConv2D": [4, 8, 4],
    "Activation": [4],
    "BatchNormalization": []
}

Each entry is a list of maximum bit widths in the order [kernel, bias, activation]. Here, we are specifying that we want to use at most 4 bits for weights and activations and at most 8 bits for biases in convolutional and depthwise convolutional layers, but we allow up to 8 bits for kernels in dense layers.

Now let's define the forgiving factor, with energy minimization as the goal. Here we are saying that we allow an 8% reduction in accuracy for a 2x reduction in energy; that both the reference model and the trials keep parameters and activations in SRAM; that neither the reference nor the trials read/write DRAM on I/O operations; and that both experiments use SRAMs sized to the minimum tensor sizes (commonly called a distributed SRAM implementation).

We also need to specify the quantizers for the inputs. In this case, we want to use int8 as source quantizers. Other possible types are int16, int32, fp16 or fp32, besides QKeras quantizer types.

Finally, to be fair, we want to compare our quantization against fixed-point 8-bit inputs, outputs, activations, weights and biases, and 32-bit accumulators.

Remember that a forgiving factor forgives a drop in a metric such as accuracy if the gains of the model are much bigger than the drop. For example, it corresponds to the sentence we allow $\tt{delta}\%$ reduction in accuracy if the quantized model has $\tt{rate} \times$ smaller energy than the original model, being a multiplicative factor to the metric. It is computed by $1 + \tt{delta} \times \log_{\tt{rate}}(\tt{stress} \times \tt{reference\_cost} / \tt{trial\_cost})$.
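
As a sanity check, here is a minimal sketch of that formula; we assume delta is expressed as a fraction (i.e. delta_p / 100, so 0.08 for the 8% used below).


In [ ]:
import math

def forgiving_factor(delta, rate, stress, reference_cost, trial_cost):
  # forgiving factor = 1 + delta * log_rate(stress * reference_cost / trial_cost)
  return 1.0 + delta * math.log(stress * reference_cost / trial_cost, rate)

# halving the energy forgives an 8% drop in accuracy ...
print(forgiving_factor(0.08, 2.0, 1.0, 2.0, 1.0))  # ~1.08
# ... and a 4x reduction forgives 16%
print(forgiving_factor(0.08, 2.0, 1.0, 4.0, 1.0))  # ~1.16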


In [ ]:
goal = {
    "type": "energy",
    "params": {
        "delta_p": 8.0,
        "delta_n": 8.0,
        "rate": 2.0,
        "stress": 1.0,
        "process": "horowitz",
        "parameters_on_memory": ["sram", "sram"],
        "activations_on_memory": ["sram", "sram"],
        "rd_wr_on_io": [False, False],
        "min_sram_size": [0, 0],
        "source_quantizers": ["int8"],
        "reference_internal": "int8",
        "reference_accumulator": "int32"
        }
}

There are a few more things we need to define. Let's bundle them in a dictionary and pass them to AutoQKeras. We will try a maximum of 20 trials (max_trials) just to limit the time we spend finding the best quantization here. Please note that this parameter is not used if you are running in hyperband mode.

output_dir is the directory where we will store our results. Since we are running in Colab, we will let tempfile choose a directory for us.

learning_rate_optimizer allows AutoQKeras to change the optimizer and the learning rate to try to improve the quantization results. Since it is still experimental, it may sometimes produce worse results.

Because we are tuning filters as well, we should set transfer_weights to False as the trainable parameters will have different shapes.

AutoQKeras has three modes of operation: random, bayesian and hyperband. We recommend referring to KerasTuner (https://keras-team.github.io/keras-tuner/) for a complete description of them.

tune_filters can be set to layer, block or none. If tune_filters is block, we change the filters by the same amount for all layers being quantized in the trial. If tune_filters is layer, we will possibly change the number of filters for each layer independently. Finally, if tune_filters is none, we will not perform filter tuning.

Together with tune_filters, tune_filters_exceptions lets the user specify, via a regular expression, which layers should be excluded from filter tuning; this is especially useful for the last layers of the network.

Filter tuning is a very important feature of AutoQKeras. When we deeply quantize a model, we may need fewer or more filters in each layer (and, as you can guess, we do not know a priori how many filters each layer will need). Here is the rationale.

  • fewer filters: let us assume we have two sets of filter coefficients we want to quantize: $[-0.3, 0.2, 0.5, 0.15]$ and $[-0.5, 0.4, 0.1, 0.65]$. If we apply a $\tt{binary}$ quantizer with $\tt{scale} = 2^{\big\lceil \log_2(\frac{\sum |w|}{N}) \big\rceil}$, where $w$ are the filter coefficients and $N$ is the number of coefficients, we end up with the same filter: $\tt{binary}([-0.3, 0.2, 0.5, 0.15]) = \tt{binary}([-0.5, 0.4, 0.1, 0.65]) = [-1,1,1,1] \times 0.5$. Here we assume the $\tt{scale}$ is a power-of-2 number so that it can be efficiently implemented by a shift operation (a small numerical check of this example appears right after this list);

  • more filters: it is clear that quantization will drop information (just look at the example above) and deep quantization will drop more information, so to recover some of the boundary regions in layers that perform feature extraction, we may need to add more filters to the layer when we quantize it.
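
Here is the small NumPy check promised above; it is an illustrative sketch of a binary quantizer with a power-of-2 scale, not the exact QKeras implementation.


In [ ]:
# Sketch only: a sign quantizer with scale = 2**ceil(log2(mean(|w|))),
# i.e. the power-of-2 scale from the bullet above.
def binary_po2(w):
  w = np.asarray(w, dtype=np.float64)
  scale = 2.0 ** np.ceil(np.log2(np.mean(np.abs(w))))
  return np.sign(w) * scale

print(binary_po2([-0.3, 0.2, 0.5, 0.15]))  # [-0.5  0.5  0.5  0.5]
print(binary_po2([-0.5, 0.4, 0.1, 0.65]))  # [-0.5  0.5  0.5  0.5]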

We do not want to quantize the softmax layer, which is the last layer of the network. In AutoQKeras, you specify which layers to quantize by their indexes in the Keras model, i.e. if you can get a layer as model.layers[i], then i is the index of that layer.
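
If you want to inspect which index corresponds to which layer before choosing layer_indexes, a plain Keras loop is enough (nothing AutoQKeras-specific here).


In [ ]:
# List the index, class and name of each layer of the model.
for i, layer in enumerate(model.layers):
  print(i, layer.__class__.__name__, layer.name)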

Finally, for data parallel distributed training, we should pass the strategy in distribution_strategy to KerasTuner.


In [ ]:
run_config = {
  "output_dir": tempfile.mkdtemp(),
  "goal": goal,
  "quantization_config": quantization_config,
  "learning_rate_optimizer": False,
  "transfer_weights": False,
  "mode": "random",
  "seed": 42,
  "limit": limit,
  "tune_filters": "layer",
  "tune_filters_exceptions": "^dense",
  "distribution_strategy": cur_strategy,
  # exclude the input layer (index 0) and the final softmax layer from quantization
  "layer_indexes": range(1, len(model.layers) - 1),
  "max_trials": 20
}

print("quantizing layers:", [model.layers[i].name for i in run_config["layer_indexes"]])

In [ ]:
autoqk = AutoQKeras(model, metrics=["acc"], custom_objects=custom_objects, **run_config)
autoqk.fit(x_train, y_train, validation_data=(x_test, y_test), batch_size=1024, epochs=20)

Now, let's see the best model we got.


In [ ]:
qmodel = autoqk.get_best_model()
qmodel.save_weights("qmodel.h5")

We got more than 90% reduction in energy here when compared to 8-bit tensors with 32-bit accumulators. Remember that our original number was 3.3 uJ for fp32; the quantized model comes in at 11 nJ, as opposed to 204 nJ for the 8-bit reference. Since these numbers come from a high-level energy model, remember to look at the ratios between them rather than the absolute values.
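
If you want to reproduce the high-level estimate for the best model yourself, the sketch below mirrors the QTools cell from earlier, now with for_reference=False so that the quantizers chosen by AutoQKeras are used; the exact numbers will depend on your run.


In [ ]:
# Sketch: same QTools/pe parameters as the reference run above, applied to the
# quantized model with for_reference=False.
q = run_qtools.QTools(
    qmodel, process="horowitz",
    source_quantizers=[quantized_bits(8, 0, 1)],
    is_inference=False, weights_path=None,
    keras_quantizer=reference_internal,
    keras_accumulator=reference_accumulator,
    for_reference=False)
trial_energy_dict = q.pe(
    weights_on_memory="sram", activations_on_memory="sram",
    min_sram_size=8*16*1024*1024, rd_wr_on_io=False)
trial_energy = q.extract_energy_sum(
    qtools_settings.cfg.include_energy, trial_energy_dict)
print("Total energy: {:.2f} uJ".format(trial_energy / 1000000.0))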

Let's train this model to see how much accuracy we can get out of it.


In [ ]:
qmodel.load_weights("qmodel.h5")
with cur_strategy.scope():
  optimizer = Adam(lr=0.02)
  qmodel.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["acc"])
  qmodel.fit(x_train, y_train, epochs=200, batch_size=4096, validation_data=(x_test, y_test))

One of the problems with trying to quantize everything in one shot is that we may end up with too many choices to make, which makes the search space very large. To reduce the search space, AutoQKeras has two methods that help users cope with this explosion of choices.

Grouping Layers to Use the Same Choice

In this case, we can provide regular expressions that specify which layer names should be grouped to use the same choice. In our example, suppose we want to group the convolution layers (except the first one) and all activations except the last one to use the same quantization.

For the first convolution layer, we want to limit the quantization types to fewer choices, as the input is already an 8-bit number. The last activation will be fed to a feature classifier layer, so we may leave it with more bits. Because in some of our models the dense layer is actually implemented as a Conv2D operation, we enable 8 bits for its weights by layer name.

We first need to look at the names of the layers for this.


In [ ]:
pprint.pprint([layer.name for layer in model.layers])

Convolution layers in the mnist model have names of the form conv2d_[01234], and activation layers have names of the form act_[01234]. So, we can create the following regular expressions to reduce the search space of our model.

Please note that layer class names always select different quantizers per layer, so the user needs to specify a pattern matching layer names if they want the same quantization to be used by a group of layers.

You can see another feature of limit here: instead of specifying the maximum number of bits, you can cherry-pick which quantizers to try for a specific layer by giving a list of quantizers from quantization_config.


In [ ]:
limit = {
    "Dense": [8, 8, 4],
    "Conv2D": [4, 8, 4],
    "DepthwiseConv2D": [4, 8, 4],
    "Activation": [4],
    "BatchNormalization": [],

    "^conv2d_0$": [
                   ["binary", "ternary", "quantized_bits(2,1,1,alpha=1.0)"],
                   8, 4
    ],
    "^conv2d_[1234]$": [4, 8, 4],
    "^act_[0123]$": [4],
    "^act_4$": [8],
    "^dense$": [8, 8, 4]
}

In [ ]:
run_config = {
  "output_dir": tempfile.mkdtemp(),
  "goal": goal,
  "quantization_config": quantization_config,
  "learning_rate_optimizer": False,
  "transfer_weights": False,
  "mode": "random",
  "seed": 42,
  "limit": limit,
  "tune_filters": "layer",
  "tune_filters_exceptions": "^dense",
  "distribution_strategy": cur_strategy,
  "layer_indexes": range(1, len(model.layers) - 1),
  "max_trials": 40
}

In [ ]:
autoqk = AutoQKeras(model, metrics=["acc"], custom_objects=custom_objects, **run_config)
autoqk.fit(x_train, y_train, validation_data=(x_test, y_test), batch_size=1024, epochs=20)

Let's see the reduction now.


In [ ]:
qmodel = autoqk.get_best_model()
qmodel.save_weights("qmodel.h5")

Let's train this model for longer to see how much accuracy we can get.


In [ ]:
qmodel.load_weights("qmodel.h5")
with cur_strategy.scope():
  optimizer = Adam(lr=0.02)
  qmodel.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["acc"])
  qmodel.fit(x_train, y_train, epochs=200, batch_size=4096, validation_data=(x_test, y_test))

Quantization by Blocks

In the previous section, we reduced the number of options by forcing groups of layers to share the same quantization decisions.

Another approach is to still let each block of layers make its own choice, but to quantize the blocks sequentially, either from inputs to outputs or by quantizing higher-energy blocks first.

The rationale for this method is that if we quantize the blocks one by one, with $N$ choices per block and $B$ blocks, we end up trying $N \times B$ options instead of $N^B$. For example, with $B = 6$ blocks and $N = 100$ choices each, that is at most $600$ trials instead of $10^{12}$. The reader should note that this is an approximation, as there is no guarantee that we will obtain the best possible quantization.

Should you go sequentially from inputs to outputs, or start from the block that has the highest impact?

If you have a network like ResNet and you want to do filter tuning, you need to group the layers following the ResNet definition of a block, i.e. including full identity or convolutional blocks, and quantize the model from inputs to outputs, so that at each stage you preserve the number of channels for the residual connection.

To perform quantization by blocks, you need to specify two additional parameters in run_config. blocks is a list of regular expressions defining the groups you want to quantize; if a layer does not match any block pattern, it will not be quantized. schedule_block specifies the block scheduling mode: sequential, or cost if you want to schedule the blocks by decreasing cost (energy or bits) first.

In this mode, there are a few optimizations that we perform automatically. First, we dynamically reduce the learning rate of the blocks that have already been quantized, as simply setting them to non-trainable does not seem to work; we still allow them to train, but at a slower pace. In addition, we try to dynamically adjust the learning rate of the layers we are currently quantizing relative to the learning rate of the unquantized layers. Finally, we transfer the weights of the blocks we have already quantized whenever we can (i.e. when the shapes remain the same).

Regardless of how we schedule the operations, we amortize the number of trials according to the cost of the block (its energy or bits relative to the total energy or number of bits of the network).

Instead of invoking AutoQKeras, we will now invoke the AutoQKeras scheduler (AutoQKerasScheduler).


In [ ]:
run_config = {
  "output_dir": tempfile.mkdtemp(),
  "goal": goal,
  "quantization_config": quantization_config,
  "learning_rate_optimizer": False,
  "transfer_weights": False,
  "mode": "random",
  "seed": 42,
  "limit": limit,
  "tune_filters": "layer",
  "tune_filters_exceptions": "^dense",
  "distribution_strategy": cur_strategy,
  "layer_indexes": range(1, len(model.layers) - 1),
  "max_trials": 40,

  "blocks": [
    "^.*_0$",
    "^.*_1$",
    "^.*_2$",
    "^.*_3$",
    "^.*_4$",
    "^dense"
  ],
  "schedule_block": "cost"
}

Because specifying regular expressions is error-prone, we recommend that you first run AutoQKerasScheduler in debug mode to print the blocks.


In [ ]:
pprint.pprint([layer.name for layer in model.layers])
autoqk = AutoQKerasScheduler(model, metrics=["acc"], custom_objects=custom_objects, debug=True, **run_config)
autoqk.fit(x_train, y_train, validation_data=(x_test, y_test), batch_size=1024, epochs=20)

All blocks seem to be fine. Let's find the best quantization now.


In [ ]:
autoqk = AutoQKerasScheduler(model, metrics=["acc"], custom_objects=custom_objects, **run_config)
autoqk.fit(x_train, y_train, validation_data=(x_test, y_test), batch_size=1024, epochs=20)

In [ ]:
qmodel = autoqk.get_best_model()
qmodel.save_weights("qmodel.h5")

In [ ]:
qmodel.load_weights("qmodel.h5")
with cur_strategy.scope():
  optimizer = Adam(lr=0.02)
  qmodel.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["acc"])
  qmodel.fit(x_train, y_train, epochs=200, batch_size=4096, validation_data=(x_test, y_test))

Perfect! You have learned how to perform automatic quantization using AutoQKeras with QKeras.


In [ ]: