In [ ]:
##### Copyright 2020 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
QKeras is a quantization extension to Keras that provides a drop-in replacement for some of the Keras layers, especially the ones that create parameters and activation layers and perform arithmetic operations, so that we can quickly create a deep quantized version of a Keras network.
According to the TensorFlow documentation, Keras is a high-level API to build and train deep learning models. It's used for fast prototyping, advanced research, and production, with three key advantages:
User friendly
Keras has a simple, consistent interface optimized for common use cases. It provides clear and actionable feedback for user errors.
Modular and composable
Keras models are made by connecting configurable building blocks together, with few restrictions.
Easy to extend
Write custom building blocks to express new ideas for research. Create new layers, loss functions, and develop state-of-the-art models.
QKeras is designed to extend the functionality of Keras following Keras' design principles, i.e. being user friendly, modular and extensible, while also being "minimally intrusive" with respect to Keras' native functionality.
QKeras has been implemented based on the work of "B.Moons et al. - Minimum Energy Quantized Neural Networks" , Asilomar Conference on Signals, Systems and Computers, 2017 and “Zhou, S. et al. DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients,” but the framework should be easily extensible. The original code from QNN can be found below.
https://github.com/BertMoons/QuantizedNeuralNetworks-Keras-Tensorflow
QKeras extends QNN by providing a richer set of layers (including SeparableConv2D, DepthwiseConv2D, ternary and stochastic ternary quantizations), in addition to functions that aid the estimation of accumulator sizes and the conversion of non-quantized networks into quantized ones. Finally, our main goal is ease of use, so we attempt to make QKeras layers a true drop-in replacement for Keras, so that users can easily exchange non-quantized layers for quantized ones.
The following layers have been implemented in QKeras.
QDense
QConv1D
QConv2D
QDepthwiseConv2D
QSeparableConv2D
(depthwise + pointwise expanded, extended from MobileNet SeparableConv2D implementation)
QActivation
QAveragePooling2D
(in fact, an AveragePooling2D stacked with a QActivation layer for quantization of the result, so this layer does not exist as a separate implementation)
QBatchNormalization
QOctaveConv2D
It is worth noting that not all functionality is safe at this time to be used with other high-level operations, such as layer wrappers. For example, Bidirectional layer wrappers are used with RNNs. If this is required, we encourage users to invoke the quantization functions as strings instead of passing the actual functions as a workaround, but we may change that implementation in the future.
QSeparableConv2D is implemented as a depthwise + pointwise quantized expansion, extended from the SeparableConv2D implementation of MobileNet. With the exception of QBatchNormalization, if quantizers are not specified, no quantization is applied to the layer and it ends up behaving like the original unquantized layer. QBatchNormalization, on the other hand, has been implemented so that if the user does not specify any quantizers as parameters, it uses a setup that has worked best when attempting to implement quantization efficiently in hardware and software: gamma and variance use po2 quantizers (as they become shift registers in an implementation, with the variance po2 quantizer further constrained to a quadratic approximation because we take the square root of the variance to obtain the standard deviation), beta uses a po2 quantizer to maintain the dynamic range aspect of the center parameter, and mean remains unquantized, as it inherits the properties of the previous layer.
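For illustration, this kind of setup could be spelled out explicitly. The sketch below is not the library's defaults verbatim; the parameter names (gamma_quantizer, variance_quantizer, beta_quantizer, mean_quantizer) and the bit widths are assumptions for illustration only.

from tensorflow.keras.layers import Input
from qkeras import QBatchNormalization

x_in = Input((28, 28, 16))
# po2 quantizers for gamma/variance/beta (shift-friendly in hardware);
# quadratic approximation on variance because we take its square root;
# mean left unquantized. Bit widths here are illustrative only.
x = QBatchNormalization(
    gamma_quantizer="quantized_po2(4)",
    variance_quantizer="quantized_po2(4, quadratic_approximation=True)",
    beta_quantizer="quantized_po2(4)",
    mean_quantizer=None,
    name="bn_example")(x_in)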
Activation has been migrated to QActivation, although QKeras also recognizes the activation parameter used in convolutional and dense layers.
We have improved the setup of quantization: convolution, dense and batch normalization layers now notify the quantizers when they are used as internal parameters, so the user does not need to worry about setting up options that work best for weights and bias, such as alpha and use_stochastic_rounding (although users may override the automatic setup).
Finally, in the current version, we have eliminated the need to set up the range of the quantizers, such as kernel_range in QDense; this is now computed automatically. Although we kept these parameters for backward compatibility, they will be removed in the future.
Quantizers and activation layers are treated interchangeably in QKeras.
The list of quantizers and their parameters is given below; a short usage sketch follows the list.
smooth_sigmoid(x)
hard_sigmoid(x)
binary_sigmoid(x)
smooth_tanh(x)
hard_tanh(x)
binary_tanh(x)
quantized_bits(bits=8, integer=0, symmetric=0, keep_negative=1, alpha=None, use_stochastic_rounding=False)(x)
bernoulli(alpha=None, temperature=6.0, use_real_sigmoid=True)(x)
stochastic_ternary(alpha=None, threshold=None, temperature=8.0, use_real_sigmoid=True)(x)
ternary(alpha=None, threshold=None, use_stochastic_rounding=False)(x)
stochastic_binary(alpha=None, temperature=6.0, use_real_sigmoid=True)(x)
binary(use_01=False, alpha=None, use_stochastic_rounding=False)(x)
quantized_relu(bits=8, integer=0, use_sigmoid=0, use_stochastic_rounding=False)(x)
quantized_ulaw(bits=8, integer=0, symmetric=0, u=255.0)(x)
quantized_tanh(bits=8, integer=0, symmetric=0, use_stochastic_rounding=False)(x)
quantized_po2(bits=8, max_value=None, use_stochastic_rounding=False, quadratic_approximation=False)(x)
quantized_relu_po2(bits=8, max_value=None, use_stochastic_rounding=False, quadratic_approximation=False)(x)
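Since quantizers are plain callables, a quick way to get a feel for them is to apply a few to a constant tensor. This is just an illustrative sketch (values and bit widths chosen arbitrarily):

import numpy as np
import tensorflow.compat.v2 as tf
from qkeras import binary, ternary, quantized_bits, quantized_relu

x = tf.constant(np.linspace(-2.0, 2.0, 9), dtype=tf.float32)
print(binary()(x).numpy())                 # values in {-1, +1}
print(ternary()(x).numpy())                # values in {-1, 0, +1} (up to scale)
print(quantized_bits(4, 0, 1)(x).numpy())  # 4-bit fixed point, symmetric
print(quantized_relu(2, 0)(x).numpy())     # 2-bit unsigned relu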
The stochastic_* functions and bernoulli rely on stochastic versions of the activation functions, so they are best suited for weights and biases. They draw a random number from a uniform distribution and compare it against the sigmoid of the input x, so that the result equals the activation function in expectation. Please refer to the papers if you want to understand the underlying theory, or to the documentation in qkeras/quantizers.py. The parameter temperature determines how steep the sigmoid function behaves, and the default values seem to work fine.
As we lower the number of bits, rounding becomes problematic as it adds bias to the number system. Numpy attempts to reduce the effects of this bias by rounding to even instead of rounding towards infinity. Recent results (Suyog Gupta, Ankur Agrawal, Kailash Gopalakrishnan, Pritish Narayanan, "Deep Learning with Limited Numerical Precision" [https://arxiv.org/abs/1502.02551]) suggest using stochastic rounding, which uses the fractional part of the number as a probability to round up or down. We can turn on stochastic rounding in some quantizers by setting use_stochastic_rounding to True in quantized_bits, binary, ternary, quantized_relu, quantized_tanh, quantized_po2, and quantized_relu_po2. Please note that if one is considering an efficient hardware or software implementation, this flag should be avoided in activations as it may affect the efficiency of an implementation. In addition, as mentioned before, we already set this flag to True in some quantized layers when the quantizers are used for weights/biases.
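As a sketch, stochastic rounding can be requested explicitly through the quantizer string of a weight quantizer (the layer shape and bit widths below are illustrative only, not recommended settings):

from tensorflow.keras.layers import Input
from qkeras import QConv2D

x_in = Input((28, 28, 1))
# Stochastic rounding on the kernel quantizer only; activations are left
# with deterministic rounding, which is friendlier to hardware.
x = QConv2D(16, (3, 3),
            kernel_quantizer="quantized_bits(4, 0, 1, use_stochastic_rounding=True)",
            bias_quantizer="quantized_po2(4)",
            name="conv_sr_example")(x_in)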
The parameter bits specifies the number of bits for the quantization, and integer specifies how many of those bits are to the left of the decimal point. Finally, in our experience training networks with QSeparableConv2D, it is advisable to allocate more bits between the depthwise and the pointwise quantization, and both quantized_bits and quantized_tanh should use symmetric versions for weights and bias in order to properly converge and eliminate bias.
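As a small illustration of how bits and integer interact, quantized_bits(bits=4, integer=1, keep_negative=1) keeps one sign bit, one integer bit and two fractional bits, so values land on a grid of multiples of 0.25 clipped to roughly [-2, 2). A quick check (a sketch, not part of the tutorial flow):

import numpy as np
from qkeras import quantized_bits

q = quantized_bits(bits=4, integer=1, keep_negative=1)
x = np.linspace(-3.0, 3.0, 25).astype("float32")
print(np.unique(q(x).numpy()))   # expected grid: -2.0, -1.75, ..., 1.75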
We have substantially improved the stochastic rounding implementation in QKeras $\geq 0.7$, and added a symbolic way to compute alpha in binary, stochastic_binary, ternary, stochastic_ternary, bernoulli and quantized_bits. Right now, the scale and the threshold (for ternary and stochastic_ternary) can be computed independently of the distribution of the inputs, which is required when using these quantizers for weights.
The main problem in using very small bit widths in large deep learning networks stems from the fact that weights are initialized with variance roughly $\propto \sqrt{1/\tt{fanin}}$, but during training the variance shifts outwards. If $\sqrt{1/\tt{fanin}}$ is smaller than the smallest quantization representation (the threshold in ternary networks), we run the risk of having the weights stuck at 0 during training. So, the weights need to dynamically adjust to the variance shift from initialization to the final training. This can be done by scaling the quantization.
The scale is computed using the formula $\sum(\tt{dot}(Q,x))/\sum(\tt{dot}(Q,Q))$, which is described in several papers, including Mohammad Rastegari, Vicente Ordonez, Joseph Redmon, Ali Farhadi, "XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks" [https://arxiv.org/abs/1603.05279]. The scale is computed for each output channel, which sometimes makes our implementation behave like a mini-batch normalization adjustment.
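The formula above can be sketched in a few lines of numpy; this is only an illustration of the per-output-channel scale computation, not the actual QKeras internals:

import numpy as np

def channel_scale(x, q):
  """scale = sum(Q*x) / sum(Q*Q), reduced over all axes except the last."""
  axes = tuple(range(x.ndim - 1))
  return np.sum(q * x, axis=axes) / (np.sum(q * q, axis=axes) + 1e-9)

w = np.random.normal(0.0, 0.02, size=(3, 3, 8, 16))  # e.g. a conv kernel
q = np.sign(w)                                        # e.g. binary quantization
print(channel_scale(w, q).shape)                      # one scale per output channel: (16,)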
For ternary and stochastic_ternary, we iterate between scale computation and threshold computation, as presented in K. Hwang and W. Sung, "Fixed-point feedforward deep neural network design using weights +1, 0, and −1," 2014 IEEE Workshop on Signal Processing Systems (SiPS), Belfast, 2014, pp. 1-6, which makes the search for threshold and scale tolerant to different input distributions. This is especially important when we need to consider that the threshold shifts depending on the input distribution, affecting the scale as well, as pointed out by Fengfu Li, Bo Zhang, Bin Liu, "Ternary Weight Networks" [https://arxiv.org/abs/1605.04711].
When computing the scale in these quantizers, if alpha="auto", we compute the scale as a floating point number. If alpha="auto_po2", we enforce the scale to be a power of 2, meaning that an actual hardware or software implementation can be performed by just shifting the result of the convolution or dense layer left or right by the log2 of the scale (a positive exponent shifts left, a negative one shifts right). This behavior is compatible with shared-exponent approaches, as it performs a shift adjustment per output channel.
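To make the auto_po2 argument concrete, here is a small sketch showing that multiplying an integer accumulator by a power-of-two scale reduces to a shift (all numbers are illustrative):

import numpy as np

scale = 0.25                             # suppose alpha="auto_po2" produced 2**-2
shift = int(np.round(np.log2(scale)))    # -2
acc = 112                                # integer result of a convolution/dense accumulator
scaled = acc << shift if shift >= 0 else acc >> -shift
print(shift, scaled)                     # -2 28, i.e. the same as acc * scale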
We have implemented a method for each quantizer called _set_trainable_parameter that instructs QKeras to set the best options when the quantizer is used for weights or for gamma, variance and beta in QBatchNormalization, so in principle users should not need to worry about this.
The following pictures show the behavior of binary vs stochastic rounding in binary vs stochastic_binary (Figure 1), and of ternary vs stochastic rounding in ternary vs stochastic_ternary (Figure 2). We generated a normally distributed input with mean 0.0 and standard deviation 0.02, sorted the data, and ran each quantizer 1,000 times, averaging the result for each case. Note that because of the scale, the output does not range over $[-1.0, +1.0]$, but over $[-\tt{scale}, +\tt{scale}]$.
QKeras works by tagging all variables and weights/biases created by Keras, as well as the outputs of arithmetic layers, with quantized functions. Quantized functions can be instantiated directly in QDense/QConv2D/QSeparableConv2D layers, and they can be passed to QActivation, which acts as a merged quantization and activation function.
In order to successfully quantize a model, users need to replace the layers that create variables (trainable or not), such as Dense and Conv2D, by their QKeras equivalents (QDense, QConv2D, etc.), and any layers that perform math operations need to be quantized afterwards.
Quantized values are clipped between their maximum and minimum quantized representation (which may be different from $[-1.0, 1.0]$), although for po2-type quantizers we still recommend that users specify the max_value parameter.
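For example, po2 quantizers could be bounded explicitly as follows (a sketch; the bit widths and bounds are arbitrary):

from qkeras import quantized_po2, quantized_relu_po2

bias_quantizer = quantized_po2(4, max_value=1)      # signed powers of two, bounded by 1
act_quantizer = quantized_relu_po2(4, max_value=4)  # non-negative powers of two, bounded by 4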
An example of a very simple network is given below in Keras.
In [1]:
import six
import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
In [2]:
def CreateModel(shape, nb_classes):
  x = x_in = Input(shape)
  x = Conv2D(18, (3, 3), name="conv2d_1")(x)
  x = Activation("relu", name="act_1")(x)
  x = Conv2D(32, (3, 3), name="conv2d_2")(x)
  x = Activation("relu", name="act_2")(x)
  x = Flatten(name="flatten")(x)
  x = Dense(nb_classes, name="dense")(x)
  x = Activation("softmax", name="softmax")(x)
  model = Model(inputs=x_in, outputs=x)
  return model
In [3]:
def get_data():
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_train = x_train.reshape(x_train.shape + (1,)).astype("float32")
  x_test = x_test.reshape(x_test.shape + (1,)).astype("float32")
  x_train /= 256.0
  x_test /= 256.0
  x_mean = np.mean(x_train, axis=0)
  x_train -= x_mean
  x_test -= x_mean
  nb_classes = np.max(y_train) + 1
  y_train = to_categorical(y_train, nb_classes)
  y_test = to_categorical(y_test, nb_classes)
  return (x_train, y_train), (x_test, y_test)
In [4]:
(x_train, y_train), (x_test, y_test) = get_data()
model = CreateModel(x_train.shape[1:], y_train.shape[-1])
In [5]:
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
In [6]:
model.fit(x_train, y_train, epochs=3, batch_size=128, validation_data=(x_test, y_test), verbose=True)
Out[6]:
Great! It is relatively easy to create a network that converges on MNIST with very high test accuracy. The reader should note that we named all the layers, as this will make it easier to automatically convert the network by name.
In [7]:
model.summary()
The corresponding quantized network is presented below.
In [8]:
from qkeras import *
def CreateQModel(shape, nb_classes):
  x = x_in = Input(shape)
  x = QConv2D(18, (3, 3),
              kernel_quantizer="stochastic_ternary",
              bias_quantizer="quantized_po2(4)",
              name="conv2d_1")(x)
  x = QActivation("quantized_relu(2)", name="act_1")(x)
  x = QConv2D(32, (3, 3),
              kernel_quantizer="stochastic_ternary",
              bias_quantizer="quantized_po2(4)",
              name="conv2d_2")(x)
  x = QActivation("quantized_relu(2)", name="act_2")(x)
  x = Flatten(name="flatten")(x)
  x = QDense(nb_classes,
             kernel_quantizer="quantized_bits(3,0,1)",
             bias_quantizer="quantized_bits(3)",
             name="dense")(x)
  x = Activation("softmax", name="softmax")(x)
  model = Model(inputs=x_in, outputs=x)
  return model
In [11]:
qmodel = CreateQModel(x_train.shape[1:], y_train.shape[-1])
In [10]:
from tensorflow.keras.optimizers import Adam
qmodel.compile(
    loss="categorical_crossentropy",
    optimizer=Adam(0.0005),
    metrics=["accuracy"])
In [44]:
qmodel.fit(x_train, y_train, epochs=10, batch_size=128, validation_data=(x_test, y_test), verbose=True)
Out[44]:
You should note that we had to lower the learning rate and train the network for a longer time. On the other hand, the network should not involve any multiplications in the convolution layers, and only very small multipliers in the dense layer.
Please note that the last Activation was not changed to QActivation, as during inference we usually perform an argmax on the result instead of softmax.
It may seem like a lot of code to write besides the main network, but in fact this additional code only specifies the sizes of the weights and, in the case of the activations, the sizes of the outputs. Right now, we do not have a way to extract this information from the network structure or from the problem we are trying to solve, and if we quantize a layer too much, we may end up not being able to recover from that later on.
In addition to the drop-in replacement of Keras functions, we have written the following function to assist anyone who wants to quantize a network.
model_quantize(model, quantizer_config, activation_bits, custom_objects=None, transfer_weights=False)
This function converts a non-quantized model (such as model from the previous example) into a quantized version, applying the configuration specified by the dictionary quantizer_config. The parameter activation_bits is used for unnamed activation functions and will probably be removed in future versions.
The parameter custom_objects specifies a dictionary of objects unknown to Keras, which is required when copying a model with lambda layers or customized layer functions, for example; and if transfer_weights is True, the returned model will have as initial weights the weights from the original model, instead of using random initial weights.
The dictionary specified in quantizer_config can be indexed by a layer name or by a layer class name. In the example below, conv2d_1 corresponds to the first convolutional layer of the example, while QConv2D corresponds to the default behavior of two-dimensional convolutional layers. The reader should note that right now we recommend using QActivation with a dictionary to avoid the conversion of activations such as softmax and linear. In addition, although we could use the activation field of the layers, we do not recommend that.
{
    "conv2d_1": {
        "kernel_quantizer": "stochastic_ternary",
        "bias_quantizer": "quantized_po2(4)"
    },
    "QConv2D": {
        "kernel_quantizer": "stochastic_ternary",
        "bias_quantizer": "quantized_po2(4)"
    },
    "QDense": {
        "kernel_quantizer": "quantized_bits(3,0,1)",
        "bias_quantizer": "quantized_bits(3)"
    },
    "act_1": "quantized_relu(2)",
    "QActivation": { "relu": "quantized_relu(2)" }
}
In the following example, we will quantize the model using a different strategy.
In [73]:
config = {
    "conv2d_1": {
        "kernel_quantizer": "stochastic_binary",
        "bias_quantizer": "quantized_po2(4)"
    },
    "QConv2D": {
        "kernel_quantizer": "stochastic_ternary",
        "bias_quantizer": "quantized_po2(4)"
    },
    "QDense": {
        "kernel_quantizer": "quantized_bits(4,0,1)",
        "bias_quantizer": "quantized_bits(4)"
    },
    "QActivation": { "relu": "binary" },
    "act_2": "quantized_relu(3)",
}
In [75]:
from qkeras.utils import model_quantize
qmodel = model_quantize(model, config, 4, transfer_weights=True)
for layer in qmodel.layers:
  if hasattr(layer, "kernel_quantizer"):
    print(layer.name, "kernel:", str(layer.kernel_quantizer_internal),
          "bias:", str(layer.bias_quantizer_internal))
  elif hasattr(layer, "quantizer"):
    print(layer.name, "quantizer:", str(layer.quantizer))

print()
qmodel.summary()
In [76]:
qmodel.compile(
    loss="categorical_crossentropy",
    optimizer=Adam(0.001),
    metrics=["accuracy"])
In [78]:
qmodel.fit(x_train, y_train, epochs=10, batch_size=128, validation_data=(x_test, y_test), verbose=True)
Out[78]:
In addition to model_quantize, QKeras offers the following utility functions.
BinaryToThermometer(x, classes, value_range, with_residue=False, merge_with_channels, use_two_hot_encoding=False)
This function converts a dense binary encoding of inputs to one-hot (with scales).
Given input matrix x
with values (for example) 0, 1, 2, 3, 4, 5, 6, 7, create a number of classes as follows:
If classes=2, value_range=8, with_residue=0, a true one-hot representation is created, and the remaining bits are truncated, using one bit representation.
0 - [1,0] 1 - [1,0] 2 - [1,0] 3 - [1,0]
4 - [0,1] 5 - [0,1] 6 - [0,1] 7 - [0,1]
If classes=2, value_range=8, with_residue=1, the residue is added to the one-hot class, and the class will use 2 bits (for the remainder) + 1 bit (for the one hot)
0 - [1,0] 1 - [1.25,0] 2 - [1.5,0] 3 - [1.75,0]
4 - [0,1] 5 - [0,1.25] 6 - [0,1.5] 7 - [0,1.75]
The arguments of this function are as follows (a small usage sketch follows the argument list):
x: the input vector we want to convert. typically its dimension will be
(B,H,W,C) for an image, or (B,T,C) or (B,C) for a 1D signal, where
B=batch, H=height, W=width, C=channels or features, T=time for time
series.
classes: the number of classes (or log2(classes) bits) to use for the
values.
value_range: max(x) - min(x) over all possible x values (e.g. for 8 bits,
we would use 256 here).
with_residue: if true, we split the value range into two sets and add
the decimal fraction of the set to the one-hot representation for partial
thermometer representation.
merge_with_channels: if True, we will not create a separate dimension
for the resulting matrix, but we will merge this dimension with
the last dimension.
use_two_hot_encoding: if true, we will distribute the weight between
the current value and the next one to make sure the numbers will always
be < 1.
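A hedged usage sketch matching the classes=2, value_range=8 example above; the import path and the explicit defaults passed below are assumptions about the QKeras API rather than documented behavior:

import numpy as np
from qkeras import BinaryToThermometer

x = np.arange(8).reshape(1, 8, 1).astype("float32")   # (B, T, C) with values 0..7
y = BinaryToThermometer(x, classes=2, value_range=8, with_residue=True,
                        merge_with_channels=False, use_two_hot_encoding=False)
print(y)   # expected to follow the 0 -> [1, 0], 1 -> [1.25, 0], ... pattern above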
model_save_quantized_weights(model, filename)
This function saves the quantized weights in the model, or writes the quantized weights to the file filename for production, since during training the weights are kept non-quantized. Typically, you should call this function before productizing the final model. The saved model is compatible with Keras for inference, so for power-of-2 quantization we do not save (sign, round(log2(weights))), but rather (-1)**sign * 2**(round(log2(weights))). We also return a dictionary containing the name of each layer and its quantized weights, and for power-of-2 quantizations we return sign and round(log2(weights)) so that other tools can properly process them.
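A typical usage sketch, assuming the qmodel from the example above has finished training (the filename is illustrative):

from qkeras.utils import model_save_quantized_weights

# Writes inference-ready (quantized) weights and returns a per-layer dictionary.
saved_weights = model_save_quantized_weights(qmodel, "qmodel_weights.h5")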
load_qmodel(filepath, custom_objects=None, compile=True)
Loads a quantized model from a Keras model.save() h5 file, where filepath is the path to the file, custom_objects is an optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization, and compile instructs QKeras to compile the model after reading it. If an optimizer was found as part of the saved model, the model is already compiled. Otherwise, the model is uncompiled and a warning will be displayed. When compile is set to False, the compilation is omitted without any warning.
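A usage sketch for saving and reloading a quantized model (the file path is illustrative):

from qkeras.utils import load_qmodel

qmodel.save("qmodel.h5")            # standard Keras save
qmodel2 = load_qmodel("qmodel.h5")  # QKeras custom objects resolved on load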
print_model_sparsity(model)
Prints sparsity for the pruned layers in the model.
quantized_model_debug(model, X_test, plot=False)
Debugs and plots model weights and activations. It is usually useful to print weights, biases and activations of inputs and outputs when debugging a model. model contains the mixed quantized/unquantized layers of a model. We only print/plot activations and weights/biases for quantized layers, with the exception of Activation. X_test is the set of inputs we will use to compute activations, and we recommend that the user use a subsample of the full set they want to debug. If plot is True, we also plot weights and activations (inputs/outputs) for each layer.
extract_model_operations(model)
As each operation depends on the quantization method for the weights/bias and on the quantization of the inputs, we estimate which operations are required for each layer of the quantized model. For example, if the inputs of a QDense layer are quantized using quantized_relu_po2 and its weights are quantized using quantized_bits, the matrix multiplication can be implemented as a barrel shifter + accumulator, without multiplication operations. Right now, we return for each layer one of the following operations: mult, barrel, mux, adder, xor, together with the sizes of the operands.
We are currently refactoring this function and it may be substantially changed in the future.
print_qstats(model)
Prints statistics of the number of operations per operation type and layer so that the user can see how big the model is. This function utilizes extract_model_operations.
An example of the output is presented below.
Number of operations in model:
    conv2d_0_m : 25088 (smult_4_8)
    conv2d_1_m : 663552 (smult_4_4)
    conv2d_2_m : 147456 (smult_4_4)
    dense : 5760 (smult_4_4)

Number of operation types in model:
    smult_4_4 : 816768
    smult_4_8 : 25088
In this example, smult_4_4 stands for a 4x4-bit signed multiplication and smult_4_8 stands for an 8x4-bit signed multiplication.
We are currently refactoring this function and it may be substantially changed in the future.
In the quantized network qmodel, let's print the statistics of the model and weights.
In [79]:
print_qstats(qmodel)
In [81]:
from qkeras.utils import quantized_model_debug
quantized_model_debug(qmodel, x_test, plot=False)
Here, the values in conv2d_1 -4.6218 4.0295 ( -1.0000 1.0000) ( -0.5000 0.5000) a( 0.125000 0.500000) correspond to the min and max values of the output of the convolution layer, the weight range (min and max), the bias range (min and max), and alpha (min and max).