In [ ]:
##### Copyright 2020 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

QKeras Lab Book

QKeras is a quantization extension to Keras that provides drop-in replacements for some of the Keras layers, especially the ones that create parameters and activation layers and that perform arithmetic operations, so that we can quickly create a deep quantized version of a Keras network.

According to the TensorFlow documentation, Keras is a high-level API to build and train deep learning models. It's used for fast prototyping, advanced research, and production, with three key advantages:

  • User friendly
    Keras has a simple, consistent interface optimized for common use cases. It provides clear and actionable feedback for user errors.

  • Modular and composable
    Keras models are made by connecting configurable building blocks together, with few restrictions.

  • Easy to extend
    Write custom building blocks to express new ideas for research. Create new layers, loss functions, and develop state-of-the-art models.

QKeras is designed to extend the functionality of Keras following Keras' own design principles, i.e. being user friendly, modular and extensible, while also being "minimally intrusive" with respect to Keras' native functionality.

QKeras has been implemented based on the work of B. Moons et al., "Minimum Energy Quantized Neural Networks," Asilomar Conference on Signals, Systems and Computers, 2017, and S. Zhou et al., "DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients," but the framework should be easily extensible. The original QNN code can be found below.

https://github.com/BertMoons/QuantizedNeuralNetworks-Keras-Tensorflow

QKeras extends QNN by providing a richer set of layers (including SeparableConv2D, DepthwiseConv2D, ternary and stochastic ternary quantizations), as well as functions to aid the estimation of accumulator sizes and the conversion of non-quantized networks into quantized ones. Finally, our main goal is ease of use, so we attempt to make QKeras layers true drop-in replacements for Keras, so that users can easily exchange non-quantized layers for quantized ones.

Layers Implemented in QKeras

The following layers have been implemented in QKeras.

  • QDense

  • QConv1D

  • QConv2D

  • QDepthwiseConv2D

  • QSeparableConv2D (depthwise + pointwise expansion, extended from the MobileNet SeparableConv2D implementation)

  • QActivation

  • QAveragePooling2D (in fact, an AveragePooling2D stacked with a QActivation layer that quantizes the result, so this layer does not exist as a separate implementation)

  • QBatchNormalization

  • QOctaveConv2D

It is worth noting that not all functionality is currently safe to use with other high-level operations, such as layer wrappers (for example, the Bidirectional wrapper used with RNNs). If this is required, we encourage users to pass the quantization functions as strings instead of the actual function objects as a workaround, though we may change this implementation in the future.
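
For instance, here is a minimal sketch of the two forms for a stand-alone QDense layer; both should behave the same in ordinary use, but the string form is the one suggested inside wrappers:

from qkeras import QDense, quantized_bits

# String form: QKeras parses the string and instantiates the quantizer
# internally (the form recommended above for use inside layer wrappers).
layer_str = QDense(10,
                   kernel_quantizer="quantized_bits(4,0,1)",
                   bias_quantizer="quantized_bits(4)")

# Callable form: equivalent for a stand-alone layer.
layer_fn = QDense(10,
                  kernel_quantizer=quantized_bits(4, 0, 1),
                  bias_quantizer=quantized_bits(4))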

QSeparableConv2D is implemented as a quantized depthwise + pointwise expansion, extended from the SeparableConv2D implementation of MobileNet. With the exception of QBatchNormalization, if quantizers are not specified, no quantization is applied to the layer and it ends up behaving like the original unquantized layer. QBatchNormalization behaves differently: if the user does not specify any quantizers as parameters, it uses a setup that has worked best when attempting to implement quantization efficiently in hardware and software. Namely, gamma and variance use po2 quantizers (as they become shift registers in an implementation, with the variance po2 quantizer further constrained to use a quadratic approximation, because we take the square root of the variance to obtain the standard deviation), beta uses a po2 quantizer to maintain the dynamic range of the center parameter, and the mean remains unquantized, as it inherits the properties of the previous layer.
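
As a minimal sketch, leaving the quantizers of QBatchNormalization unspecified selects these defaults:

from tensorflow.keras.layers import Input
from qkeras import QConv2D, QBatchNormalization

x = x_in = Input((28, 28, 1))
x = QConv2D(16, (3, 3),
            kernel_quantizer="quantized_bits(4,0,1)",
            bias_quantizer="quantized_bits(4)",
            name="conv")(x)
# No quantizers specified: gamma/variance/beta fall back to the po2-based
# defaults described above, and the mean remains unquantized.
x = QBatchNormalization(name="bn")(x)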

Activation has been migrated to QActivation, although QKeras also recognizes the activation parameter used in convolutional and dense layers.

We have improved the setup of quantization: convolution, dense and batch normalization layers now notify the quantizers when they are used as internal parameters, so the user does not need to worry about setting the options that work best for weights and biases, such as alpha and use_stochastic_rounding (although users may override the automatic setup).

Finally, the current version eliminates the need to set the range of the quantizers, such as kernel_range in QDense: it is now computed automatically. The parameters are kept for backward compatibility, but they will be removed in the future.

Activation Layers and Quantizers Implemented in QKeras

Quantizers and activation layers are treated interchangeably in QKeras.

The quantizers and their parameters are listed below.

  • smooth_sigmoid(x)

  • hard_sigmoid(x)

  • binary_sigmoid(x)

  • smooth_tanh(x)

  • hard_tanh(x)

  • binary_tanh(x)

  • quantized_bits(bits=8, integer=0, symmetric=0, keep_negative=1, alpha=None, use_stochastic_rounding=False)(x)

  • bernoulli(alpha=None, temperature=6.0, use_real_sigmoid=True)(x)

  • stochastic_ternary(alpha=None, threshold=None, temperature=8.0, use_real_sigmoid=True)(x)

  • ternary(alpha=None, threshold=None, use_stochastic_rounding=False)(x)

  • stochastic_binary(alpha=None, temperature=6.0, use_real_sigmoid=True)(x)

  • binary(use_01=False, alpha=None, use_stochastic_rounding=False)(x)

  • quantized_relu(bits=8, integer=0, use_sigmoid=0, use_stochastic_rounding=False)(x)

  • quantized_ulaw(bits=8, integer=0, symmetric=0, u=255.0)(x)

  • quantized_tanh(bits=8, integer=0, symmetric=0, use_stochastic_rounding=False)(x)

  • quantized_po2(bits=8, max_value=None, use_stochastic_rounding=False, quadratic_approximation=False)(x)

  • quantized_relu_po2(bits=8, max_value=None, use_stochastic_rounding=False, quadratic_approximation=False)(x)

The stochastic_* functions and bernoulli rely on stochastic versions of the activation functions, so they are best suited for weights and biases. They draw a uniformly distributed random number and compare it with the sigmoid of the input x, so that the expected value of the output matches the expected value of the activation function. Please refer to the papers, or to the documentation in qkeras/quantizers.py, if you want to understand the underlying theory. The parameter temperature determines how steep the sigmoid is, and the default values seem to work fine.
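
Since quantizers are plain callables on tensors, you can probe them directly. A small sketch (TF 2 eager execution assumed; the exact values depend on the quantizer defaults):

import numpy as np
from qkeras import quantized_bits, ternary

x = np.linspace(-1.0, 1.0, 9).astype("float32")

# 4-bit fixed point with no integer bits: 1 sign bit + 3 fractional bits,
# so values snap to multiples of 2^-3 in [-1, 0.875].
q = quantized_bits(bits=4, integer=0)
print(np.array(q(x)))

# Ternary maps values to {-1, 0, +1} (times a scale when alpha is set).
t = ternary()
print(np.array(t(x)))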

As we lower the number of bits, rounding becomes problematic, as it adds bias to the number system. NumPy attempts to reduce the effect of this bias by rounding to even instead of rounding away from zero. Recent results (Suyog Gupta, Ankur Agrawal, Kailash Gopalakrishnan, Pritish Narayanan, "Deep Learning with Limited Numerical Precision" [https://arxiv.org/abs/1502.02551]) suggest using stochastic rounding, which uses the fractional part of the number as the probability of rounding up or down. Stochastic rounding can be turned on in some quantizers by setting use_stochastic_rounding to True in quantized_bits, binary, ternary, quantized_relu, quantized_tanh, quantized_po2, and quantized_relu_po2. Please note that if one is considering an efficient hardware or software implementation, this flag should not be set to True in activations, as it may affect the efficiency of an implementation. In addition, as mentioned before, we already set this flag to True in some quantized layers when the quantizers are used for weights/biases.
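
The effect is easy to see by averaging many stochastic quantizations of the same value. A sketch (results vary run to run; depending on the QKeras version, stochastic rounding may only be active in training mode, hence the learning-phase call):

import numpy as np
import tensorflow as tf
from qkeras import quantized_relu

tf.keras.backend.set_learning_phase(1)  # ensure training-mode behavior

x = np.full((10000,), 0.3, dtype="float32")

# Deterministic rounding: every element lands on the same grid point (0.25).
print(float(np.mean(quantized_relu(2, 0)(x))))
# Stochastic rounding: elements round up or down with probability given by
# the fractional part, so the mean approaches 0.3.
print(float(np.mean(quantized_relu(2, 0, use_stochastic_rounding=True)(x))))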

The parameter bits specifies the number of bits for the quantization, and integer specifies how many of those bits are to the left of the decimal point. Finally, in our experience training networks with QSeparableConv2D, it is advisable to allocate more bits between the depthwise and the pointwise quantization, and both quantized_bits and quantized_tanh should use symmetric versions for weights and biases in order to properly converge and eliminate the bias.

We have substantially improved the stochastic rounding implementation in QKeras >= 0.7, and added a symbolic way to compute alpha in binary, stochastic_binary, ternary, stochastic_ternary, bernoulli and quantized_bits. Right now, a scale and a threshold (for ternary and stochastic_ternary) can be computed independently of the distribution of the inputs, which is required when using these quantizers in weights.

The main problem in using very small bit widths in large deep learning networks stems from the fact that weights are initialized with standard deviation roughly $\propto \sqrt{1/\tt{fanin}}$, but during training the variance shifts outwards. If the smallest quantization representation (the threshold in ternary networks) is larger than $\sqrt{1/\tt{fanin}}$, most weights fall below it and we run the risk of having the weights stuck at 0 during training. So, the weights need to dynamically adjust to the variance shift from initialization to the final training. This can be done by scaling the quantization.

The scale is computed using the formula $\sum(\tt{dot}(Q,x))/\sum(\tt{dot}(Q,Q))$, which is described in several papers, including Mohammad Rastegari, Vicente Ordonez, Joseph Redmon, Ali Farhadi, "XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks" [https://arxiv.org/abs/1603.05279]. The scale is computed for each output channel, which sometimes makes our implementation behave like a mini-batch normalization adjustment.
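
For a single output channel with binary quantization, this formula reduces to the mean absolute value of the weights; a NumPy sketch:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0.0, 0.02, size=1000)  # weights of one output channel
Q = np.sign(x)                        # binary quantization to {-1, +1}

# Least-squares scale: argmin over s of ||x - s*Q||^2
scale = np.sum(Q * x) / np.sum(Q * Q)
print(scale, np.mean(np.abs(x)))      # identical for binary Q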

For ternary and stochastic_ternary, we iterate between scale computation and threshold computation, as presented in K. Hwang and W. Sung, "Fixed-point feedforward deep neural network design using weights +1, 0, and −1," 2014 IEEE Workshop on Signal Processing Systems (SiPS), Belfast, 2014, pp. 1-6, which makes the search for threshold and scale tolerant to different input distributions. This is especially important when we consider that the threshold shifts depending on the input distribution, affecting the scale as well, as pointed out by Fengfu Li, Bo Zhang, Bin Liu, "Ternary Weight Networks" [https://arxiv.org/abs/1605.04711].

When computing the scale in these quantizers, if alpha="auto", we compute the scale as a floating point number. If alpha="auto_po2", we enforce the scale to be a power of 2, meaning that an actual hardware or software implementation can apply it by taking log2 of the scale and shifting the result of the convolution or dense layer left or right by that amount (a positive log2 shifts left, a negative one shifts right). This behavior is compatible with shared exponent approaches, as it performs a shift adjustment to the channel.
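
As a worked example of the auto_po2 case (a sketch with made-up numbers):

import math

acc = 112                      # integer result of a convolution
scale = 0.25                   # power-of-2 scale chosen with alpha="auto_po2"
shift = int(math.log2(scale))  # -2; a negative log2 means shift right
assert acc * scale == acc >> -shift  # 112 * 0.25 == 112 >> 2 == 28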

We have implemented a method for each quantizer called _set_trainable_parameter that instructs QKeras to set the best options when the quantizer is used for weights or for gamma, variance and beta in QBatchNormalization, so in principle users should not need to worry about this.

The following pictures show the behavior of deterministic vs stochastic rounding in binary vs stochastic_binary (Figure 1) and in ternary vs stochastic_ternary (Figure 2). We generated a normally distributed input with mean 0.0 and standard deviation 0.02, sorted the data, and ran each quantizer 1,000 times, averaging the result for each case (a NumPy sketch of this experiment appears after the figure captions). Note that because of the scale, the output does not range over $[-1.0, +1.0]$, but over $[-\tt{scale}, +\tt{scale}]$.

    Figure 1: Behavior of binary quantizers
    Figure 2: Behavior of ternary quantizers
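
The following is a plain NumPy sketch of that averaging experiment for the stochastic binary case; it mimics the procedure described above, not the exact QKeras implementation (the temperature scaling by the input standard deviation is illustrative):

import numpy as np

rng = np.random.default_rng(0)
x = np.sort(rng.normal(0.0, 0.02, size=1000))

def stochastic_binary_np(x, temperature=6.0, std=0.02):
    # Probability of outputting +1 from a steepened sigmoid of the input.
    p = 1.0 / (1.0 + np.exp(-temperature * x / std))
    return np.where(rng.random(x.shape) < p, 1.0, -1.0)

# Average 1,000 stochastic quantizations: a smooth transfer curve emerges.
avg = np.mean([stochastic_binary_np(x) for _ in range(1000)], axis=0)
print(avg.min(), avg.max())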

Using QKeras

QKeras works by tagging all variables and weights/biases created by Keras, as well as the outputs of arithmetic layers, with quantized functions. Quantized functions can be instantiated directly in QDense/QConv2D/QSeparableConv2D layers, and they can be passed to QActivation, which acts as a merged quantization and activation function.

In order to successfully quantize a model, users need to replace the layers that create variables (trainable or not), such as Dense and Conv2D, by their QKeras equivalents (QDense, QConv2D, etc), and any layers that perform math operations need to be quantized afterwards.

Quantized values are clipped between their maximum and minimum quantized representation (which may be different from $[-1.0, 1.0]$), although for po2 types of quantizers we still recommend that users specify the max_value parameter.
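
For example, a sketch of a bounded po2 quantizer (the values are made up for illustration):

import numpy as np
from qkeras import quantized_po2

# Power-of-2 quantizer with the magnitude bounded by max_value=1, so
# outputs are signed powers of 2 clipped to [-1, 1].
q = quantized_po2(4, max_value=1)
x = np.array([-3.0, -0.3, 0.05, 0.7, 2.5], dtype="float32")
print(np.array(q(x)))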

An example of a very simple network is given below in Keras.


In [1]:
import six
import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical



In [2]:
def CreateModel(shape, nb_classes):
    x = x_in = Input(shape)
    x = Conv2D(18, (3, 3), name="conv2d_1")(x)
    x = Activation("relu", name="act_1")(x)
    x = Conv2D(32, (3, 3), name="conv2d_2")(x)
    x = Activation("relu", name="act_2")(x)
    x = Flatten(name="flatten")(x)
    x = Dense(nb_classes, name="dense")(x)
    x = Activation("softmax", name="softmax")(x)
    
    model = Model(inputs=x_in, outputs=x)

    return model

In [3]:
def get_data():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.reshape(x_train.shape + (1,)).astype("float32")
    x_test = x_test.reshape(x_test.shape + (1,)).astype("float32")

    # Scale pixel values to [0, 1) and center with the training-set mean.
    x_train /= 256.0
    x_test /= 256.0

    x_mean = np.mean(x_train, axis=0)

    x_train -= x_mean
    x_test -= x_mean

    nb_classes = np.max(y_train)+1
    y_train = to_categorical(y_train, nb_classes)
    y_test = to_categorical(y_test, nb_classes)

    return (x_train, y_train), (x_test, y_test)

In [4]:
(x_train, y_train), (x_test, y_test) = get_data()

model = CreateModel(x_train.shape[1:], y_train.shape[-1])

In [5]:
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

In [6]:
model.fit(x_train, y_train, epochs=3, batch_size=128, validation_data=(x_test, y_test), verbose=True)


Train on 60000 samples, validate on 10000 samples
Epoch 1/3
60000/60000 [==============================] - 7s 122us/sample - loss: 0.2103 - accuracy: 0.9393 - val_loss: 0.0685 - val_accuracy: 0.9781
Epoch 2/3
60000/60000 [==============================] - 6s 108us/sample - loss: 0.0642 - accuracy: 0.9808 - val_loss: 0.0575 - val_accuracy: 0.9817
Epoch 3/3
60000/60000 [==============================] - 7s 112us/sample - loss: 0.0457 - accuracy: 0.9860 - val_loss: 0.0502 - val_accuracy: 0.9844
Out[6]:
<tensorflow.python.keras.callbacks.History at 0x7fd5d70ccbd0>

Great! It is relatively easy to create a network that converges on MNIST with very high test accuracy. The reader should note that we named all the layers, as this will make it easier to automatically convert the network by name.


In [7]:
model.summary()


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 26, 26, 18)        180       
_________________________________________________________________
act_1 (Activation)           (None, 26, 26, 18)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 24, 24, 32)        5216      
_________________________________________________________________
act_2 (Activation)           (None, 24, 24, 32)        0         
_________________________________________________________________
flatten (Flatten)            (None, 18432)             0         
_________________________________________________________________
dense (Dense)                (None, 10)                184330    
_________________________________________________________________
softmax (Activation)         (None, 10)                0         
=================================================================
Total params: 189,726
Trainable params: 189,726
Non-trainable params: 0
_________________________________________________________________

The corresponding quantized network is presented below.


In [8]:
from qkeras import *

def CreateQModel(shape, nb_classes):
    x = x_in = Input(shape)
    x = QConv2D(18, (3, 3),
        kernel_quantizer="stochastic_ternary", 
        bias_quantizer="quantized_po2(4)",
        name="conv2d_1")(x)
    x = QActivation("quantized_relu(2)", name="act_1")(x)
    x = QConv2D(32, (3, 3), 
        kernel_quantizer="stochastic_ternary", 
        bias_quantizer="quantized_po2(4)",
        name="conv2d_2")(x)
    x = QActivation("quantized_relu(2)", name="act_2")(x)
    x = Flatten(name="flatten")(x)
    x = QDense(nb_classes,
        kernel_quantizer="quantized_bits(3,0,1)",
        bias_quantizer="quantized_bits(3)",
        name="dense")(x)
    x = Activation("softmax", name="softmax")(x)
    
    model = Model(inputs=x_in, outputs=x)
    
    return model

In [11]:
qmodel = CreateQModel(x_train.shape[1:], y_train.shape[-1])



In [10]:
from tensorflow.keras.optimizers import Adam

qmodel.compile(
    loss="categorical_crossentropy",
    optimizer=Adam(0.0005),
    metrics=["accuracy"])



In [44]:
qmodel.fit(x_train, y_train, epochs=10, batch_size=128, validation_data=(x_test, y_test), verbose=True)


Train on 60000 samples, validate on 10000 samples
Epoch 1/10
60000/60000 [==============================] - 52s 869us/sample - loss: 0.5034 - accuracy: 0.8428 - val_loss: 0.2422 - val_accuracy: 0.9276
Epoch 2/10
60000/60000 [==============================] - 49s 813us/sample - loss: 0.2080 - accuracy: 0.9371 - val_loss: 0.1981 - val_accuracy: 0.9415
Epoch 3/10
60000/60000 [==============================] - 50s 832us/sample - loss: 0.1703 - accuracy: 0.9503 - val_loss: 0.1454 - val_accuracy: 0.9571
Epoch 4/10
60000/60000 [==============================] - 49s 813us/sample - loss: 0.1448 - accuracy: 0.9573 - val_loss: 0.1296 - val_accuracy: 0.9616
Epoch 5/10
60000/60000 [==============================] - 48s 806us/sample - loss: 0.1225 - accuracy: 0.9639 - val_loss: 0.1267 - val_accuracy: 0.9636
Epoch 6/10
60000/60000 [==============================] - 49s 818us/sample - loss: 0.1074 - accuracy: 0.9686 - val_loss: 0.1092 - val_accuracy: 0.9662
Epoch 7/10
60000/60000 [==============================] - 48s 801us/sample - loss: 0.1020 - accuracy: 0.9698 - val_loss: 0.1041 - val_accuracy: 0.9718
Epoch 8/10
60000/60000 [==============================] - 50s 825us/sample - loss: 0.0979 - accuracy: 0.9712 - val_loss: 0.1175 - val_accuracy: 0.9662
Epoch 9/10
60000/60000 [==============================] - 49s 823us/sample - loss: 0.0910 - accuracy: 0.9726 - val_loss: 0.1015 - val_accuracy: 0.9711
Epoch 10/10
60000/60000 [==============================] - 50s 836us/sample - loss: 0.0845 - accuracy: 0.9755 - val_loss: 0.1044 - val_accuracy: 0.9710
Out[44]:
<tensorflow.python.keras.callbacks.History at 0x14d1fb0b8>

You should note that we had to lower the learning rate and train the network for a longer time. On the other hand, the network should not involve any multiplications in the convolution layers, and only very small multipliers in the dense layers.

Please note that the last Activation was not changed to QActivation, as during inference we usually perform an argmax operation on the result instead of softmax.

It seems like a lot of code to write besides the main network, but in fact this additional code is only specifying the sizes of the weights and, in the case of the activations, the sizes of the outputs. Right now, we do not have a way to extract this information from the network structure or the problem we are trying to solve, and if we quantize a layer too much, we may end up not being able to recover from that later on.

Converting a Model Automatically

In addition to the drop-in replacement of Keras functions, we have written the following function to assist anyone who wants to quantize a network.

model_quantize(model, quantizer_config, activation_bits, custom_objects=None, transfer_weights=False)

This function converts a non-quantized model (such as model in the previous example) into a quantized version, by applying the configuration specified by the dictionary quantizer_config. activation_bits is used for unnamed activation functions, and this parameter will probably be removed in future versions.

The parameter custom_objects specifies a dictionary of objects unknown to Keras, required when you copy a model with lambda layers or customized layer functions, for example. If transfer_weights is True, the returned model will have as its initial weights the weights of the original model, instead of random initial weights.

The dictionary specified in quantizer_config can be indexed by a layer name or a layer class name. In the example below, conv2d_1 corresponds to the first convolutional layer of the example, while QConv2D corresponds to the default behavior of two-dimensional convolutional layers. The reader should note that right now we recommend using QActivation with a dictionary to avoid the conversion of activations such as softmax and linear. In addition, although we could use the activation field in the layers, we do not recommend that.

{ "conv2d_1": { "kernel_quantizer": "stochastic_ternary", "bias_quantizer": "quantized_po2(4)" }, "QConv2D": { "kernel_quantizer": "stochastic_ternary", "bias_quantizer": "quantized_po2(4)" }, "QDense": { "kernel_quantizer": "quantized_bits(3,0,1)", "bias_quantizer": "quantized_bits(3)" }, "act_1": "quantized_relu(2)", "QActivation": { "relu": "quantized_relu(2)" } }

In the following example, we will quantize the model using a different strategy.


In [73]:
config = {
  "conv2d_1": {
      "kernel_quantizer": "stochastic_binary",
      "bias_quantizer": "quantized_po2(4)"
  },
  "QConv2D": {
      "kernel_quantizer": "stochastic_ternary",
      "bias_quantizer": "quantized_po2(4)"
  },
  "QDense": {
      "kernel_quantizer": "quantized_bits(4,0,1)",
      "bias_quantizer": "quantized_bits(4)"
  },
  "QActivation": { "relu": "binary" },
  "act_2": "quantized_relu(3)",
}

In [75]:
from qkeras.utils import model_quantize

qmodel = model_quantize(model, config, 4, transfer_weights=True)

for layer in qmodel.layers:
    if hasattr(layer, "kernel_quantizer"):
        print(layer.name, "kernel:", str(layer.kernel_quantizer_internal), "bias:", str(layer.bias_quantizer_internal))
    elif hasattr(layer, "quantizer"):
        print(layer.name, "quantizer:", str(layer.quantizer))

print()
qmodel.summary()


conv2d_1 kernel: stochastic_binary(alpha=auto_po2) bias: quantized_po2(4)
act_1 quantizer: binary()
conv2d_2 kernel: stochastic_ternary(alpha=auto_po2,threshold=0.33) bias: quantized_po2(4)
act_2 quantizer: quantized_relu(3,0)
dense kernel: quantized_bits(4,0,1,alpha=auto_po2,use_stochastic_rounding=1) bias: quantized_bits(4,0,1,alpha=auto_po2,use_stochastic_rounding=1)

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_1 (QConv2D)           (None, 26, 26, 18)        180       
_________________________________________________________________
act_1 (QActivation)          (None, 26, 26, 18)        0         
_________________________________________________________________
conv2d_2 (QConv2D)           (None, 24, 24, 32)        5216      
_________________________________________________________________
act_2 (QActivation)          (None, 24, 24, 32)        0         
_________________________________________________________________
flatten (Flatten)            (None, 18432)             0         
_________________________________________________________________
dense (QDense)               (None, 10)                184330    
_________________________________________________________________
softmax (Activation)         (None, 10)                0         
=================================================================
Total params: 189,726
Trainable params: 189,726
Non-trainable params: 0
_________________________________________________________________

In [76]:
qmodel.compile(
    loss="categorical_crossentropy",
    optimizer=Adam(0.001),
    metrics=["accuracy"])

In [78]:
qmodel.fit(x_train, y_train, epochs=10, batch_size=128, validation_data=(x_test, y_test), verbose=True)


Train on 60000 samples, validate on 10000 samples
Epoch 1/10
60000/60000 [==============================] - 43s 722us/sample - loss: 0.0796 - accuracy: 0.9757 - val_loss: 0.0968 - val_accuracy: 0.9725
Epoch 2/10
60000/60000 [==============================] - 44s 726us/sample - loss: 0.0717 - accuracy: 0.9776 - val_loss: 0.0853 - val_accuracy: 0.9747
Epoch 3/10
60000/60000 [==============================] - 45s 748us/sample - loss: 0.0690 - accuracy: 0.9789 - val_loss: 0.0878 - val_accuracy: 0.9740
Epoch 4/10
60000/60000 [==============================] - 44s 733us/sample - loss: 0.0636 - accuracy: 0.9800 - val_loss: 0.0816 - val_accuracy: 0.9753
Epoch 5/10
60000/60000 [==============================] - 45s 758us/sample - loss: 0.0572 - accuracy: 0.9822 - val_loss: 0.0861 - val_accuracy: 0.9753
Epoch 6/10
60000/60000 [==============================] - 44s 734us/sample - loss: 0.0565 - accuracy: 0.9819 - val_loss: 0.0819 - val_accuracy: 0.9763
Epoch 7/10
60000/60000 [==============================] - 46s 765us/sample - loss: 0.0489 - accuracy: 0.9842 - val_loss: 0.0859 - val_accuracy: 0.9758
Epoch 8/10
60000/60000 [==============================] - 47s 783us/sample - loss: 0.0485 - accuracy: 0.9845 - val_loss: 0.0889 - val_accuracy: 0.9771
Epoch 9/10
60000/60000 [==============================] - 44s 737us/sample - loss: 0.0477 - accuracy: 0.9850 - val_loss: 0.0729 - val_accuracy: 0.9791
Epoch 10/10
60000/60000 [==============================] - 45s 742us/sample - loss: 0.0484 - accuracy: 0.9843 - val_loss: 0.0796 - val_accuracy: 0.9780
Out[78]:
<tensorflow.python.keras.callbacks.History at 0x16b144828>

In addition to model_quantize, QKeras offers the following utility functions.

BinaryToThermometer(x, classes, value_range, with_residue=False, merge_with_channels, use_two_hot_encoding=False)

This function converts a dense binary encoding of inputs to one-hot (with scales).

Given an input matrix x with values (for example) 0, 1, 2, 3, 4, 5, 6, 7, create a number of classes as follows:

If classes=2, value_range=8, with_residue=0, a true one-hot representation is created, and the remaining bits are truncated, using a one-bit representation:

    0 - [1,0]
    1 - [1,0]
    2 - [1,0]
    3 - [1,0]
    4 - [0,1]
    5 - [0,1]
    6 - [0,1]
    7 - [0,1]

If classes=2, value_range=8, with_residue=1, the residue is added to the one-hot class, and the class will use 2 bits (for the remainder) + 1 bit (for the one-hot):

    0 - [1,0]
    1 - [1.25,0]
    2 - [1.5,0]
    3 - [1.75,0]
    4 - [0,1]
    5 - [0,1.25]
    6 - [0,1.5]
    7 - [0,1.75]

The arguments of this function are as follows:

  • x: the input vector we want to convert. Typically its dimension will be (B,H,W,C) for an image, or (B,T,C) or (B,C) for a 1D signal, where B=batch, H=height, W=width, C=channels or features, and T=time for time series.

  • classes: the number of classes (or log2(classes) bits) to use for the values.

  • value_range: max(x) - min(x) over all possible x values (e.g. for 8 bits, we would use 256 here).

  • with_residue: if True, we split the value range into two sets and add the decimal fraction of the set to the one-hot representation, resulting in a partial thermometer representation.

  • merge_with_channels: if True, we will not create a separate dimension for the resulting matrix, but will merge this dimension with the last dimension.

  • use_two_hot_encoding: if True, we will distribute the weight between the current value and the next one to make sure the numbers will always be < 1.
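
To make the with_residue mapping above concrete, here is a plain NumPy sketch of the classes=2, value_range=8 case (an illustration, not the QKeras implementation):

import numpy as np

def binary_to_thermometer_np(x, classes=2, value_range=8, with_residue=True):
    bucket = value_range // classes             # values per class (4 here)
    hot = np.eye(classes)[x // bucket]          # one-hot class encoding
    if with_residue:
        residue = (x % bucket) / float(bucket)  # fractional part in [0, 1)
        hot = hot * (1.0 + residue[..., None])  # scale the hot entry
    return hot

print(binary_to_thermometer_np(np.arange(8)))
# rows: [1,0], [1.25,0], [1.5,0], [1.75,0], [0,1], [0,1.25], [0,1.5], [0,1.75]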

model_save_quantized_weights(model, filename)

This function saves the quantized weights in the model, or writes them to the file filename for production, as the weights are kept non-quantized during training. Typically, you should call this function before productizing the final model. The saved model is compatible with Keras for inference, so for power-of-2 quantization we do not store (sign, round(log2(weights))), but rather (-1)**sign * 2**round(log2(weights)). The function also returns a dictionary containing the name of each layer and its quantized weights; in that dictionary, for power-of-2 quantizations, we return sign and round(log2(weights)) so that other tools can process them properly.
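
A typical call, as a sketch (the file name is just an example):

from qkeras.utils import model_save_quantized_weights

# Save the quantized weights of the trained model for deployment; the
# returned dictionary maps layer names to their quantized weights.
saved_weights = model_save_quantized_weights(qmodel, "qmodel_weights.h5")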

load_qmodel(filepath, custom_objects=None, compile=True)

Loads a quantized model from an h5 file saved with Keras's model.save(), where filepath is the path to the file, custom_objects is an optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization, and compile instructs QKeras to compile the model after reading it. If an optimizer was found as part of the saved model, the model is already compiled. Otherwise, the model is uncompiled and a warning will be displayed. When compile is set to False, the compilation is omitted without any warning.
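
A sketch of a save/load round trip (the file name is an example):

from qkeras.utils import load_qmodel

# Save the full model with Keras, then reload it with the QKeras custom
# objects resolved automatically.
qmodel.save("qmodel.h5")
qmodel_loaded = load_qmodel("qmodel.h5")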

print_model_sparsity(model)

Prints sparsity for the pruned layers in the model.

quantized_model_debug(model, X_test, plot=False)

Debugs and plots model weights and activations. It is usually useful to print weights, biases and activations of inputs and outputs when debugging a model. model contains the mixed quantized/unquantized layers of a model. We only print/plot activations and weights/biases for quantized layers, with the exception of Activation. X_test is the set of inputs used to compute the activations; we recommend that users pass a subsample of the set they want to debug. If plot is True, we also plot the weights and activations (inputs/outputs) of each layer.

extract_model_operations(model)

As each operation depends on the quantization method for the weights/biases and on the quantization of the inputs, we estimate which operations are required for each layer of the quantized model. For example, if the inputs of a QDense layer are quantized using quantized_relu_po2 and the weights are quantized using quantized_bits, the matrix multiplication can be implemented as a barrel shifter + accumulator, without multiplication operations. Right now, we return one of the following operations for each layer: mult, barrel, mux, adder, xor, together with the sizes of the operands.

We are currently refactoring this function and it may be substantially changed in the future.

print_qstats(model)

Prints statistics on the number of operations per operation type and per layer, so that the user can see how big the model is. This function uses extract_model_operations.

An example of the output is presented below.

Number of operations in model:
    conv2d_0_m                    : 25088 (smult_4_8)
    conv2d_1_m                    : 663552 (smult_4_4)
    conv2d_2_m                    : 147456 (smult_4_4)
    dense                         : 5760 (smult_4_4)

Number of operation types in model:
    smult_4_4                     : 816768
    smult_4_8                     : 25088

In this example, smult_4_4 stands for a 4x4-bit signed multiplication and smult_4_8 stands for a 4x8-bit signed multiplication.

We are currently refactoring this function and it may be substantially changed in the future.

For the quantized network qmodel, let's print the statistics of the model and its weights.


In [79]:
print_qstats(qmodel)


Number of operations in model:
    conv2d_1                      : 109512 (smux_1_8)
    conv2d_2                      : 2985984 (smux_2_1)
    dense                         : 184320 (smult_4_3)

Number of operation types in model:
    smult_4_3                     : 184320
    smux_1_8                      : 109512
    smux_2_1                      : 2985984

Weight profiling:
    conv2d_1_weights               : 162   (1-bit unit)
    conv2d_1_bias                  : 18    (4-bit unit)
    conv2d_2_weights               : 5184  (2-bit unit)
    conv2d_2_bias                  : 32    (4-bit unit)
    dense_weights                  : 184320 (4-bit unit)
    dense_bias                     : 10    (4-bit unit)

In [81]:
from qkeras.utils import quantized_model_debug

quantized_model_debug(qmodel, x_test, plot=False)


input                           -0.5451   0.9960
conv2d_1                        -4.6218   4.0295 ( -1.0000   1.0000) ( -0.5000   0.5000) a(  0.125000   0.500000)
act_1                           -1.0000   1.0000
conv2d_2                       -21.2500  14.2500 ( -1.0000   1.0000) ( -0.2500  -0.1250) a(  0.125000   0.250000)
act_2                            0.0000   0.8750
dense                          -52.1094  39.4062 ( -0.5000   0.3750) ( -0.1250   0.1250) a(  1.000000   1.000000)
softmax                          0.0000   1.0000

Here, the values in the line conv2d_1 -4.6218 4.0295 ( -1.0000 1.0000) ( -0.5000 0.5000) a( 0.125000 0.500000) correspond to the min and max values of the output of the convolution layer, the weight range (min and max), the bias range (min and max), and alpha (min and max).